@@ -322,6 +322,114 @@ def test_insert(scalars_dfs, loc, column, value, allow_duplicates):
322
322
pd .testing .assert_frame_equal (bf_df .to_pandas (), pd_df , check_dtype = False )
323
323
324
324
325
+ def test_where_series_cond (scalars_df_index , scalars_pandas_df_index ):
326
+ # Condition is dataframe, other is None (as default).
327
+ cond_bf = scalars_df_index ["int64_col" ] > 0
328
+ cond_pd = scalars_pandas_df_index ["int64_col" ] > 0
329
+ bf_result = scalars_df_index .where (cond_bf ).to_pandas ()
330
+ pd_result = scalars_pandas_df_index .where (cond_pd )
331
+ pandas .testing .assert_frame_equal (bf_result , pd_result )
332
+
333
+
334
+ def test_where_series_multi_index (scalars_df_index , scalars_pandas_df_index ):
335
+ # Test when a dataframe has multi-index or multi-columns.
336
+ columns = ["int64_col" , "float64_col" ]
337
+ dataframe_bf = scalars_df_index [columns ]
338
+
339
+ dataframe_bf .columns = pd .MultiIndex .from_tuples (
340
+ [("str1" , 1 ), ("str2" , 2 )], names = ["STR" , "INT" ]
341
+ )
342
+ cond_bf = dataframe_bf ["str1" ] > 0
343
+
344
+ with pytest .raises (NotImplementedError ) as context :
345
+ dataframe_bf .where (cond_bf ).to_pandas ()
346
+ assert (
347
+ str (context .value )
348
+ == "The dataframe.where() method does not support multi-index and/or multi-column."
349
+ )
350
+
351
+
352
+ def test_where_series_cond_const_other (scalars_df_index , scalars_pandas_df_index ):
353
+ # Condition is a series, other is a constant.
354
+ columns = ["int64_col" , "float64_col" ]
355
+ dataframe_bf = scalars_df_index [columns ]
356
+ dataframe_pd = scalars_pandas_df_index [columns ]
357
+ dataframe_bf .columns .name = "test_name"
358
+ dataframe_pd .columns .name = "test_name"
359
+
360
+ cond_bf = dataframe_bf ["int64_col" ] > 0
361
+ cond_pd = dataframe_pd ["int64_col" ] > 0
362
+ other = 0
363
+
364
+ bf_result = dataframe_bf .where (cond_bf , other ).to_pandas ()
365
+ pd_result = dataframe_pd .where (cond_pd , other )
366
+ pandas .testing .assert_frame_equal (bf_result , pd_result )
367
+
368
+
369
+ def test_where_series_cond_dataframe_other (scalars_df_index , scalars_pandas_df_index ):
370
+ # Condition is a series, other is a dataframe.
371
+ columns = ["int64_col" , "float64_col" ]
372
+ dataframe_bf = scalars_df_index [columns ]
373
+ dataframe_pd = scalars_pandas_df_index [columns ]
374
+
375
+ cond_bf = dataframe_bf ["int64_col" ] > 0
376
+ cond_pd = dataframe_pd ["int64_col" ] > 0
377
+ other_bf = - dataframe_bf
378
+ other_pd = - dataframe_pd
379
+
380
+ bf_result = dataframe_bf .where (cond_bf , other_bf ).to_pandas ()
381
+ pd_result = dataframe_pd .where (cond_pd , other_pd )
382
+ pandas .testing .assert_frame_equal (bf_result , pd_result )
383
+
384
+
385
+ def test_where_dataframe_cond (scalars_df_index , scalars_pandas_df_index ):
386
+ # Condition is a dataframe, other is None.
387
+ columns = ["int64_col" , "float64_col" ]
388
+ dataframe_bf = scalars_df_index [columns ]
389
+ dataframe_pd = scalars_pandas_df_index [columns ]
390
+
391
+ cond_bf = dataframe_bf > 0
392
+ cond_pd = dataframe_pd > 0
393
+
394
+ bf_result = dataframe_bf .where (cond_bf , None ).to_pandas ()
395
+ pd_result = dataframe_pd .where (cond_pd , None )
396
+ pandas .testing .assert_frame_equal (bf_result , pd_result )
397
+
398
+
399
+ def test_where_dataframe_cond_const_other (scalars_df_index , scalars_pandas_df_index ):
400
+ # Condition is a dataframe, other is a constant.
401
+ columns = ["int64_col" , "float64_col" ]
402
+ dataframe_bf = scalars_df_index [columns ]
403
+ dataframe_pd = scalars_pandas_df_index [columns ]
404
+
405
+ cond_bf = dataframe_bf > 0
406
+ cond_pd = dataframe_pd > 0
407
+ other_bf = 10
408
+ other_pd = 10
409
+
410
+ bf_result = dataframe_bf .where (cond_bf , other_bf ).to_pandas ()
411
+ pd_result = dataframe_pd .where (cond_pd , other_pd )
412
+ pandas .testing .assert_frame_equal (bf_result , pd_result )
413
+
414
+
415
+ def test_where_dataframe_cond_dataframe_other (
416
+ scalars_df_index , scalars_pandas_df_index
417
+ ):
418
+ # Condition is a dataframe, other is a dataframe.
419
+ columns = ["int64_col" , "float64_col" ]
420
+ dataframe_bf = scalars_df_index [columns ]
421
+ dataframe_pd = scalars_pandas_df_index [columns ]
422
+
423
+ cond_bf = dataframe_bf > 0
424
+ cond_pd = dataframe_pd > 0
425
+ other_bf = dataframe_bf * 2
426
+ other_pd = dataframe_pd * 2
427
+
428
+ bf_result = dataframe_bf .where (cond_bf , other_bf ).to_pandas ()
429
+ pd_result = dataframe_pd .where (cond_pd , other_pd )
430
+ pandas .testing .assert_frame_equal (bf_result , pd_result )
431
+
432
+
325
433
def test_drop_column (scalars_dfs ):
326
434
scalars_df , scalars_pandas_df = scalars_dfs
327
435
col_name = "int64_col"
0 commit comments