ENH: Array API dispatching (backport #2096) #2545
Merged
* ENH: array api dispatching added array-api-compat to test env * Deselect some scikit-learn Array API tests * deselect more tests * deselect more tests * disabled tests for * fix the deselection comment * disabled test for Ridge regression * Disabled tests and added comment * ENH: Array API dispatching * Revert adding dpctl into Array PI conformance testing added versioning for the get_nnamespace * minor refactoring onedal _array_api * add tests * addressed memory usage tests * Address some array api test fails * linting * addressed test_get_namespace * adding test case for validate_data check with Array API inputs * minor refactoring * addressed test_patch_map_match fail * Added docstrings for get_namespace * docstrings for Array API tests * updated minimal scikit-learn version for Array API dispatching * updated minimal scikit-learn version for Array API dispatching in _device_offload.py _array_api.py * fix test test_get_namespace_with_config_context * refactor onedal/datatypes/_data_conversion.py * correction for array api * Update conftest.py * introduce tags * fix imports * see if this works * really lazy logic introduction * introduce IntelEstimator * missing change in knn * recofigure logic * strip out dpnp/dpctl special code, will come back to it later * switchover * Update __init__.py * Update test_array_api.py * merge main into PR * try to fix changes: * add test: * attempt to get things running * remove tsne failure * try again * add sklearn_check_version * oops * update * another try * try again * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * fixes * attempt to re-introduce changes * Update _device_offload.py * switch get_tags * merge main into branch * remove leftover code * Update base.py * formatting * correct bad sklearn recommendation * remove irrelevant new tests * fix mistake * lint fix * add fallback test * make test fixes * formatting * rejigger the logic again * Update 
test_config.py * interim fix * Update test_config.py * attempt to get test_config test to skip * fix test * interim fixes * fixes for IncrementalEmpiricalCovariance * fixes for wrap_output_data * solve issues related to older numpy * Update test_run_to_run_stability.py * Update __init__.py * Update __init__.py * Update test_run_to_run_stability.py * Update base.py * Update test_run_to_run_stability.py * Update test_config.py * Update test_patching.py * Update test_patching.py * Update test_patching.py * Update k_means.py * Update k_means.py * Update k_means.py * Update _array_api.py * Update _array_api.py * Update _device_offload.py * Update _device_offload.py * Update _device_offload.py * Update _device_offload.py * Update test_patching.py * Update test_patching.py * Update test_patching.py * local verify doclinks * fix doclink * fixes * remove print statement * fix changes from scikit-learn/scikit-learn#29774 * Update sklearnex/_device_offload.py Co-authored-by: david-cortes-intel <david.cortes@intel.com> * add type hints and fix docs for dispatch * remove change from local testing * be more explicit with a type * python3.9 fix * change language which was bothering me * formatting * This should fix it (I think) * Apply suggestions from code review Co-authored-by: david-cortes-intel <david.cortes@intel.com> * Update base.py * Update __init__.py * remove hack on config * Apply suggestions from code review Co-authored-by: david-cortes-intel <david.cortes@intel.com> * make suggested changes * fix import in test_patching * Update test_patching.py * Update build-and-test-lnx.yml * Update build-and-test-lnx.yml * Update build-and-test-lnx.yml * Apply suggestions from code review Co-authored-by: david-cortes-intel <david.cortes@intel.com> * Apply suggestions from code review Co-authored-by: Andreas Huber <9201869+ahuber21@users.noreply.github.com> Co-authored-by: david-cortes-intel <david.cortes@intel.com> * formatting and ABC change * fix gpu test * remove workaround 
--------- Co-authored-by: Faust, Ian <ian.faust@intel.com> Co-authored-by: Ian Faust <icfaust@gmail.com> Co-authored-by: david-cortes-intel <david.cortes@intel.com> Co-authored-by: Andreas Huber <9201869+ahuber21@users.noreply.github.com> (cherry picked from commit 43596ee)
napetrov
approved these changes
Jun 13, 2025
Description
This PR enables the interface for scikit-learn-intelex estimators to dispatch Array API inputs to oneDAL. This removes the need to move data to host and back in the case of on-device data (see the need behind `use_raw_input`) in properly enabled estimators. This PR will enable `array_api_dispatch` for sklearn >=1.2. As it is experimental in sklearn, the direct use of sklearn functions may fail (see Ensemble estimators' `apply`). Scikit-learn 1.2 does not have functional array checking with `array_api_dispatch`, so sklearnex/onedal array_api support with zero-copy will begin in earnest with sklearn 1.3. Enablement of the zero-copy will be for sklearn versions >=1.3 going forward.

To accomplish this, the following related changes had to be made:
- `oneDALEstimator` is created in `base.py` which combines the capabilities of `PatchableEstimator` and `ExtensionEstimator`. All sklearnex estimators must inherit this before sklearn's `BaseEstimator` in the MRO. This is enforced by testing in `test_common.py`. This is needed in order to provide a tag for signifying an Array API supported sklearnex estimator without cluttering the inheritance of the estimators. Note: as a consequence the HTML docs support is limited to the range of sklearn versions which are also supported (>=1.4), and the logic is rewritten to minimize maintenance. This required changes to TSNE as well.
- `get_tags` is introduced to check for sklearn's tags and to minimize maintenance. This duck-types the capability of recent (>=1.6) changes to the tag system to older sklearn versions, simplifying the code. This is needed to check for array_api support and onedal array_api support via the tag system with minimal maintenance for older versions.
- `oneDALEstimator`
- `sklearnex/_device_offload.py` is refactored. This new logic changes the cumbersome passing of strings from `_get_backend` into a boolean flag. A second checkover of inputs was removed from that function.
- `dispatch` is extensively refactored in order to check for array_api support. The logic there is of particular note due to the increase in branching. It attempts at all times to avoid checking or moving data unless necessary, and therefore checks array_api support first (as this is data-movement free).
- `test_standard_estimator_patching` and `test_special_estimator_patching` now reference a combined `_check_estimator_patching` function to reduce code duplication.
- `test_standard_estimator_patching` now includes array_api dispatching for the array-api-strict data type. Changes coming in dpnp 0.17 and dpctl 0.19 will make `array_api_dispatch` a necessity, requiring it for dataframe support of non-numpy, non-pandas inputs going forward. This is a significant task that must be done separately and will mean further modification to the test suite.
- `array_api_strict` inputs are now done with `config_context(array_api_dispatch=True)` which verifies proper operation of sklearnex patched methods. This required changes to `KMeans.transform`, `IncrementalEmpiricalCovariance.mahalanobis`, `IncrementalEmpiricalCovariance.score`, and `KMeans.fit_transform` beyond the changes to the standard dispatcher.
- `test_fallback_to_host` is added due to the complexity of `dispatch`. This validates that the various scenarios with offloading + fallbacks place the data properly while still preserving the queue.

The following bugs needed to be solved to get things working properly:
- `is_dpctl_device_available` calls from test_memory_usage need to be a list input, not a string.
- `transform_output` used in _device_offload `support_input_format` and `wrap_output_data` needs to check for `default` and not `None`. The array_api logic should only apply for non-numpy inputs, but should be able to convert numpy results. Previously this logic was always skipped.
- `allow_fallback_to_host` is corrected for when used in conjunction with `target_offload`
- `_asarray` does not work for older numpy versions, as arrays do not have the `__array_namespace__` attribute. Following the guideline from the array_api standard, a check for the buffer protocol is added in line to match: https://github.com/data-apis/array-api-strict/blob/25cc3d7ff0069b222d228d380e6d95bbf9a5dbcf/array_api_strict/_creation_functions.py#L50
- `.dtype`-having objects which will then flatten, as array_api_strict arrays are very picky with the use of `__iter__`. This will simplify the analysis of 2d-array types, which were otherwise split unnecessarily into 1d arrays for individual analysis.
- `ElasticNet` and `Lasso`

Additional cleanup changes
- `test_patching` is removed
- `LogisticRegression`, as cleanup to the inheritance is undertaken in this PR.
- `get_tags` in `sklearnex.utils` makes the function `get_array_api_support_tag` unnecessary in sklearnex/_device_offload, so it is removed. It also attempts to match sklearn public API calls for simplicity in development.
- `BaseEstimator` is removed, as this is inherited from the various sklearn `SVM` algorithms in inheritance
- `use_raw_input` in `dispatch` is inverted and placed at the beginning for simplicity in analysis at no cost in computational time.

As none of the `onedal/` estimators are array_api enabled yet, the code is not fully tested, and additional PRs will be necessary to properly test and verify array_api support. These are not simple PRs and are to be included later before any estimator becomes array_api enabled. The initial testing should be sufficient for simple array_api support verification, and prevent non-functional changes for array_api support from entering the codebase.

Follow-up work is needed to support extracting queues from array_api inputs in a way that complies with the new sycl_queue management system. As of this PR, it will not properly understand array_api devices. This will modify a separate file and must be a follow-on PR.
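
The check ordering described for `dispatch` (the `use_raw_input` check first, then the data-movement-free array_api tag check, then device support with an optional host fallback) might be pictured with the sketch below. It is an illustration under assumptions, not the code in `sklearnex/_device_offload.py`: `Request`, `choose_path`, every field name, and the branch labels are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Request:
    """Hypothetical summary of what a dispatcher knows about one call."""

    use_raw_input: bool            # user asked to skip all input checks
    array_api_dispatch: bool       # sklearn's array_api_dispatch config is on
    estimator_array_api_tag: bool  # estimator advertises array_api support
    onedal_supports_device: bool   # oneDAL can run this call on the target device
    onedal_supports_host: bool     # oneDAL can run this call on the host
    allow_fallback_to_host: bool   # user permits moving data back to host


def choose_path(req: Request) -> str:
    """Return a label for the branch taken, cheapest checks first."""
    # 1. Raw input: nothing is inspected or moved.
    if req.use_raw_input:
        return "onedal-raw"
    # 2. Tag and config checks touch no data, so they precede any conversion.
    zero_copy = req.array_api_dispatch and req.estimator_array_api_tag
    # 3. Device execution, with or without a conversion step.
    if req.onedal_supports_device:
        return "onedal-device-zero-copy" if zero_copy else "onedal-device-converted"
    # 4. Fallbacks: host oneDAL if permitted, otherwise stock scikit-learn.
    if req.allow_fallback_to_host and req.onedal_supports_host:
        return "onedal-host"
    return "sklearn"


if __name__ == "__main__":
    req = Request(
        use_raw_input=False,
        array_api_dispatch=True,
        estimator_array_api_tag=True,
        onedal_supports_device=True,
        onedal_supports_host=True,
        allow_fallback_to_host=False,
    )
    print(choose_path(req))
```

Putting the flag and tag checks ahead of anything that inspects the arrays is what keeps the common path free of data movement in this toy model.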
It is likely that this PR needs to be separated based on individual themes contained within that were necessary to enable the dispatching. Help from the reviewers is greatly desired (and can be based solely on the description).
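
One theme that could be reviewed on its own is the tag lookup. Below is a minimal sketch of the duck-typing idea behind `get_tags`, assuming only that newer scikit-learn estimators expose `__sklearn_tags__()` returning an object with an `array_api_support` attribute, while older ones expose `_get_tags()` returning a dict. The helper names and stand-in classes are hypothetical and are not the `sklearnex.utils` implementation.

```python
from types import SimpleNamespace


def get_tags_compat(estimator):
    """Return a tags object with attribute access for either tag style."""
    if hasattr(estimator, "__sklearn_tags__"):
        # Newer style: already an object with attributes.
        return estimator.__sklearn_tags__()
    # Older style: wrap the dict so callers can use the same attribute access.
    return SimpleNamespace(**estimator._get_tags())


def supports_array_api(estimator) -> bool:
    """Read the tag, defaulting to False when a version lacks it entirely."""
    return bool(getattr(get_tags_compat(estimator), "array_api_support", False))


class NewStyleEstimator:
    """Stand-in for an estimator using the newer tag system."""

    def __sklearn_tags__(self):
        return SimpleNamespace(array_api_support=True)


class OldStyleEstimator:
    """Stand-in for an estimator using the older dict-based tags."""

    def _get_tags(self):
        return {"array_api_support": False}


class VeryOldStyleEstimator:
    """Stand-in for a version whose tag dict has no array_api entry."""

    def _get_tags(self):
        return {"requires_y": False}


if __name__ == "__main__":
    for est in (NewStyleEstimator(), OldStyleEstimator(), VeryOldStyleEstimator()):
        print(type(est).__name__, supports_array_api(est))
```

Callers then read one attribute regardless of version, which is the maintenance saving the description refers to.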
NOTE: The expansion of testing for array_api to use `config_context(array_api_dispatch=True)` has uncovered some bugs which were heretofore unnoticed. First, the array_api dispatching of `IncrementalEmpiricalCovariance.score` for sklearn <1.3 is not possible, and has been deselected. Second, `PCA.score_samples` and `PCA.score` do not currently work, as the return values from `PCA.fit` (`._precision` and `._mean`) need to be the same type as the input `X` values. As of now, the return values and estimator attributes are not properly checked for type matching, and conversions to non-numpy types from oneDAL are inconsistently implemented. This is a significant undertaking and must be done as a follow-up with the individual array_api support in oneDAL while maintaining quality code. This will also require the completion of `__dlpack__` support for oneDAL table types, such that `from_table` can return all array_api types.

Checklist to comply with before moving PR from draft:
PR completeness and readability
Testing
Performance
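
Regarding the type-matching gap raised in the NOTE, a follow-up test could compare the array namespace of fitted attributes against that of the input. The sketch below is only an assumption about how such a check might look: `namespace_of`, `mismatched_attributes`, `FakeArray`, and `FakePCA` are hypothetical stand-ins, and the only real API relied on is the array API standard's `__array_namespace__` method.

```python
def namespace_of(obj):
    """Return obj's array API namespace, or None if it does not define one."""
    getter = getattr(obj, "__array_namespace__", None)
    return getter() if getter is not None else None


def mismatched_attributes(estimator, X, names):
    """List attribute names whose namespace differs from the namespace of X."""
    expected = namespace_of(X)
    return [n for n in names if namespace_of(getattr(estimator, n)) is not expected]


# Sentinels standing in for two different array libraries.
HOST_NS, DEVICE_NS = object(), object()


class FakeArray:
    """Stand-in array that only reports which namespace it belongs to."""

    def __init__(self, ns):
        self._ns = ns

    def __array_namespace__(self, api_version=None):
        return self._ns


class FakePCA:
    """Stand-in estimator holding two fitted attributes."""

    def __init__(self, precision, mean):
        self._precision = precision
        self._mean = mean


if __name__ == "__main__":
    X = FakeArray(DEVICE_NS)
    est = FakePCA(precision=FakeArray(HOST_NS), mean=FakeArray(DEVICE_NS))
    print(mismatched_attributes(est, X, ["_precision", "_mean"]))
```

A check of this shape flags attributes that silently came back as a different array type than the input.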
This is an automatic backport of pull request #2096 done by [Mergify](https://mergify.com).