Skip to content

ENH: Array API dispatching (backport #2096)#2545

Merged
napetrov merged 1 commit into
rls/2025.5.0-rlsfrom
mergify/bp/rls/2025.5.0-rls/pr-2096
Jun 13, 2025
Merged

ENH: Array API dispatching (backport #2096)#2545
napetrov merged 1 commit into
rls/2025.5.0-rlsfrom
mergify/bp/rls/2025.5.0-rls/pr-2096

Conversation

@mergify
Copy link
Copy Markdown

@mergify mergify Bot commented Jun 13, 2025

Description

This PR enables the interface for scikit-learn-intelex estimators to dispatch Array API inputs to oneDAL. This removes the need to move data to host and back in the case of on-device data (see the need behind use_raw_input) in properly enabled estimators. This PR will enable array_api_dispatching for sklearn >=1.2 . As it is experimental in sklearn versions, the direct use of sklearn functions may fail (see Ensemble estimators' apply). Scikit-learn 1.2 does not have functional array checking with array_api_dispatch, so sklearnex/onedal array_api support with zero-copy will begin in earnest with sklearn 1.3. Enablement of the zero-copy will be for sklearn versions >=1.3 going forward.

To accomplish this, the following related changes had to be made

  • A new object oneDALEstimator is created in base.py which combines the capabilities of PatchableEstimator and ExtensionEstimator. All sklearnex estimators must inherit this before sklearn's BaseEstimator in the mro. This is enforced by testing in test_common.py. This is needed in order to provide a tag for signifying an Array_API supported sklearnex estimator without cluttering the inheritance of the estimators. Note: as a consequence the HTML docs support is limited to the range of sklearn versions which are also supported (>=1.4), and is the logic is rewritten to minimize maintenance. This required changes to TSNE as well.
  • A function/class get_tags is introduced to check for sklearn's tags and to minimize maintenance. This duck-types the capability of recent (>=1.6) changes to the tag system to older sklearn versions, simplifying the code. This is needed to check for array_api support and onedal array_api support via the tag system with minimal maintenance for older versions.
  • All sklearnex estimators are modified to inherit the oneDALEstimator
  • The code for sklearnex/_device_offload.py is refactored. This new logic changes the cumbersome passing of strings from _get_backend into a boolean flag. A second checkover of inputs was removed from that function.
  • The core function dispatch is extensively refactored in order to check for array_api support. The logic there is of particular note due to the increase in branching. It attempts at all times to minimize the checking or movement of data unless necessary, therefore requiring checking array_api support first (as this is data-movement free).
  • tests test_standard_estimator_patching and test_special_estimator_patching now reference a combined _check_estimator_patching function to reduce code duplication. test_standard_estimator_patching now includes array_api dispatching for the array-api-strict data type. Changes coming in dpnp 0.17 and dpctl 0.19 will make array_api_dispatch a necessity, requiring that for dataframe support for non-numpy, non-pandas inputs going forward. This is a significant task and must be done separately and also meaning further modification to the test suite.
  • Standard estimator testing with array_api_strict inputs are now done with config_context(array_api_dispatch=True) which verifies proper operation of sklearnex patched methods. This required changes to KMeans.transform, IncrementalEmpiricalCovariance.mahalanobis, IncrementalEmpiricalCovariance.score, and KMeans.fit_transform beyond the changes to the standard dispatcher.
  • a new test test_fallback_to_host is added due to the complexity of dispatch. This validates that the various scenarios with offloading + fallbacks place the data properly while still preserving the queue

The following bugs needed to be solved to get things working properly:

  • is_dpctl_device_available calls from test_memory_usage need to be a list input, not a string.
  • checks on transform_output used in _device_offload support_input_format and wrap_output_data need to check for default and not None, the array_api logic should only apply for non-numpy inputs, but should be able to convert numpy results. previously this logic was always skipped
  • support for allow_fallback_to_host is corrected for when used in conjunction with target_offload
  • _asarray does not work for older numpy versions, as arrays do not have the __array_namespace__ attribute. Following the guideline from the array_api standard, a check for the buffer protocol is added in line to match: https://github.com/data-apis/array-api-strict/blob/25cc3d7ff0069b222d228d380e6d95bbf9a5dbcf/array_api_strict/_creation_functions.py#L50
  • test_run_to_run_stability includes a branch for .dtype-having objects which will then flatten, as array_api_strict arrays are very picky with the use of __iter__. This will simplify the analysis of 2d-array types, which have otherwise split them unnecessarily into 1d arrays for individual analysis.
  • Add missing doclink support for ElasticNet and Lasso

Additional cleanup changes

  • use of the import of inspect for just inspect.signature in test_patching is removed
  • An unnecessary base class is removed from LogisticRegression, as cleanup to the inheritance is undertaken in this PR.
  • creation of get_tags in sklearnex.utils makes the function get_array_api_support_tag unnecessary in sklearnex/_device_offload and is removed. It also attempts to match sklearn public API calls for simplicity in development.
  • Inheritance in BaseSVM of BaseEstimator is removed, as this is inherited from the various sklearn SVM algorithms in inheritance
  • The logic for use_raw_input in dispatch is inverted and placed at the beginning for simplicity in analysis at no cost in computational time.

As none of the onedal/ estimators are array_api enabled yet, the code is not fully tested, and additional PRs will be necessary to properly test and verify array_api support. These are not simple PRs and are to be included later before any estimator becomes array_api enabled. The initial testing should be sufficient for simple array_api support verification, and prevent non-functional changes for array_api support from entering the codebase.

Follow up work must be made for support extracting queues from array_api inputs which comply with the new sycl_queue management system. As of this PR, it will not properly understand array_api devices. This will modify a separate file and must be a follow-on PR.

It is likely that this PR needs to be separated based on individual themes contained within that were necessary to enable the dispatching. Help from the reviewers is greatly desired (and can be based solely on the description).

NOTE: The expansion of testing for array_api to use config_context(array_api_dispatch=True) has uncovered some bugs which were heretofore unnoticed. 1st the array_api dispatching of IncrementalEmpiricalCovariance.score for sklearn <1.3 is not possible, and has been deselected. 2nd PCA.score_sample and PCA.score do not currently work as the return values from PCA.fit (._precision and ._mean) need to be the same type as the input X values. As of now, the return values and estimator attributes are not properly checked for type matching, and conversions to non-numpy types from oneDAL is inconsistently implemented. This is a significant undertaking and must be done as a follow-up with the individual array_api support in oneDAL while maintaining quality code. This will also require the completion of __dlpack__ support to oneDAL table types, such that from_table can return all array_api types.


Checklist to comply with before moving PR from draft:

PR completeness and readability

  • I have reviewed my changes thoroughly before submitting this pull request.
  • I have commented my code, particularly in hard-to-understand areas.
  • I have updated the documentation to reflect the changes or created a separate PR with update and provided its number in the description, if necessary.
  • Git commit message contains an appropriate signed-off-by string (see CONTRIBUTING.md for details).
  • I have added a respective label(s) to PR if I have a permission for that.
  • I have resolved any merge conflicts that might occur with the base branch.

Testing

  • I have run it locally and tested the changes extensively.
  • All CI jobs are green or I have provided justification why they aren't.
  • I have extended testing suite if new functionality was introduced in this PR.

Performance

  • I have measured performance for affected algorithms using scikit-learn_bench and provided at least summary table with measured data, if performance change is expected.
  • I have provided justification why performance has changed or why changes are not expected.
  • I have provided justification why quality metrics have changed or why changes are not expected.
  • I have extended benchmarking suite and provided corresponding scikit-learn_bench PR if new measurable functionality was introduced in this PR.

This is an automatic backport of pull request #2096 done by [Mergify](https://mergify.com).

* ENH: array api dispatching

added array-api-compat to test env

* Deselect some scikit-learn Array API tests

* deselect more tests

* deselect more tests

* disabled tests for

* fix the deselection comment

* disabled test for Ridge regression

* Disabled tests and added comment

* ENH: Array API dispatching

* Revert adding dpctl into Array PI conformance testing

added versioning for the get_nnamespace

* minor refactoring onedal _array_api

* add tests

* addressed memory usage tests

* Address some array api test fails

* linting

* addressed test_get_namespace

* adding test case for validate_data check with Array API inputs

* minor refactoring

* addressed test_patch_map_match fail

* Added docstrings for get_namespace

* docstrings for Array API tests

* updated minimal scikit-learn version for Array API dispatching

* updated minimal scikit-learn version for Array API dispatching in _device_offload.py _array_api.py

* fix test test_get_namespace_with_config_context

* refactor onedal/datatypes/_data_conversion.py

* correction for array api

* Update conftest.py

* introduce tags

* fix imports

* see if this works

* really lazy logic introduction

* introduce IntelEstimator

* missing change in knn

* recofigure logic

* strip out dpnp/dpctl special code, will come back to it later

* switchover

* Update __init__.py

* Update test_array_api.py

* merge main into PR

* try to fix changes:

* add test:

* attempt to get things running

* remove tsne failure

* try again

* add sklearn_check_version

* oops

* update

* another try

* try again

* Update test_common.py

* Update test_common.py

* Update test_common.py

* Update test_common.py

* Update test_common.py

* fixes

* attempt to re-introduce changes

* Update _device_offload.py

* switch get_tags

* merge main into branch

* remove leftover code

* Update base.py

* formatting

* correct bad sklearn recommendation

* remove irrelevant new tests

* fix mistake

* lint fix

* add fallback test

* make test fixes

* formatting

* rejigger the logic again

* Update test_config.py

* interim fix

* Update test_config.py

* attempt to get test_config test to skip

* fix test

* interim fixes

* fixes for IncrementalEmpiricalCovariance

* fixes for wrap_output_data

* solve issues related to older numpy

* Update test_run_to_run_stability.py

* Update __init__.py

* Update __init__.py

* Update test_run_to_run_stability.py

* Update base.py

* Update test_run_to_run_stability.py

* Update test_config.py

* Update test_patching.py

* Update test_patching.py

* Update test_patching.py

* Update k_means.py

* Update k_means.py

* Update k_means.py

* Update _array_api.py

* Update _array_api.py

* Update _device_offload.py

* Update _device_offload.py

* Update _device_offload.py

* Update _device_offload.py

* Update test_patching.py

* Update test_patching.py

* Update test_patching.py

* local verify doclinks

* fix doclink

* fixes

* remove print statement

* fix changes from scikit-learn/scikit-learn#29774

* Update sklearnex/_device_offload.py

Co-authored-by: david-cortes-intel <david.cortes@intel.com>

* add type hints and fix docs for dispatch

* remove change from local testing

* be more explicit with a type

* python3.9 fix

* change language which was bothering me

* formatting

* This should fix it (I think)

* Apply suggestions from code review

Co-authored-by: david-cortes-intel <david.cortes@intel.com>

* Update base.py

* Update __init__.py

* remove hack on config

* Apply suggestions from code review

Co-authored-by: david-cortes-intel <david.cortes@intel.com>

* make suggested changes

* fix import in test_patching

* Update test_patching.py

* Update build-and-test-lnx.yml

* Update build-and-test-lnx.yml

* Update build-and-test-lnx.yml

* Apply suggestions from code review

Co-authored-by: david-cortes-intel <david.cortes@intel.com>

* Apply suggestions from code review

Co-authored-by: Andreas Huber <9201869+ahuber21@users.noreply.github.com>
Co-authored-by: david-cortes-intel <david.cortes@intel.com>

* formatting and ABC change

* fix gpu test

* remove workaround

---------

Co-authored-by: Faust, Ian <ian.faust@intel.com>
Co-authored-by: Ian Faust <icfaust@gmail.com>
Co-authored-by: david-cortes-intel <david.cortes@intel.com>
Co-authored-by: Andreas Huber <9201869+ahuber21@users.noreply.github.com>
(cherry picked from commit 43596ee)
@mergify mergify Bot mentioned this pull request Jun 13, 2025
13 tasks
@codecov
Copy link
Copy Markdown

codecov Bot commented Jun 13, 2025

Codecov Report

Attention: Patch coverage is 79.25532% with 39 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
sklearnex/_device_offload.py 66.66% 15 Missing and 4 partials ⚠️
sklearnex/base.py 75.75% 8 Missing ⚠️
onedal/utils/_sycl_queue_manager.py 25.00% 3 Missing and 3 partials ⚠️
sklearnex/linear_model/logistic_regression.py 40.00% 6 Missing ⚠️
Flag Coverage Δ
azure 79.90% <74.46%> (+0.13%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files with missing lines Coverage Δ
onedal/_device_offload.py 82.88% <100.00%> (+2.70%) ⬆️
onedal/utils/_array_api.py 84.78% <100.00%> (+22.28%) ⬆️
sklearnex/_utils.py 83.56% <100.00%> (-2.01%) ⬇️
sklearnex/basic_statistics/basic_statistics.py 89.85% <100.00%> (+0.14%) ⬆️
...x/basic_statistics/incremental_basic_statistics.py 96.03% <100.00%> (+0.03%) ⬆️
sklearnex/cluster/dbscan.py 79.22% <100.00%> (+0.27%) ⬆️
sklearnex/cluster/k_means.py 89.65% <100.00%> (-0.15%) ⬇️
sklearnex/covariance/incremental_covariance.py 91.42% <100.00%> (+0.98%) ⬆️
sklearnex/decomposition/pca.py 88.35% <100.00%> (ø)
sklearnex/ensemble/_forest.py 80.03% <100.00%> (+0.03%) ⬆️
... and 16 more
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@napetrov napetrov merged commit c3ef708 into rls/2025.5.0-rls Jun 13, 2025
20 of 23 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants