1818
1919import dataclasses
2020import functools
21- from typing import cast , Union
21+ import io
22+ import itertools
23+ import json
24+ from typing import Any , Callable , cast , Generator , Iterable , Literal , Optional , Union
2225import uuid
2326
2427import geopandas # type: ignore
2528import numpy as np
2629import pandas
2730import pyarrow as pa
31+ import pyarrow .parquet # type: ignore
2832
2933import bigframes .core .schema as schemata
3034import bigframes .dtypes
@@ -42,7 +46,9 @@ def from_arrow(cls, table: pa.Table) -> LocalTableMetadata:
4246
4347_MANAGED_STORAGE_TYPES_OVERRIDES : dict [bigframes .dtypes .Dtype , pa .DataType ] = {
4448 # wkt to be precise
45- bigframes .dtypes .GEO_DTYPE : pa .string ()
49+ bigframes .dtypes .GEO_DTYPE : pa .string (),
50+ # Just json as string
51+ bigframes .dtypes .JSON_DTYPE : pa .string (),
4652}
4753
4854
@@ -90,6 +96,50 @@ def from_pyarrow(self, table: pa.Table) -> ManagedArrowTable:
9096 schemata .ArraySchema (tuple (fields )),
9197 )
9298
99+ def to_parquet (
100+ self ,
101+ dst : Union [str , io .IOBase ],
102+ * ,
103+ offsets_col : Optional [str ] = None ,
104+ geo_format : Literal ["wkb" , "wkt" ] = "wkt" ,
105+ duration_type : Literal ["int" , "duration" ] = "duration" ,
106+ json_type : Literal ["string" ] = "string" ,
107+ ):
108+ pa_table = self .data
109+ if offsets_col is not None :
110+ pa_table = pa_table .append_column (
111+ offsets_col , pa .array (range (pa_table .num_rows ), type = pa .int64 ())
112+ )
113+ if geo_format != "wkt" :
114+ raise NotImplementedError (f"geo format { geo_format } not yet implemented" )
115+ if duration_type != "duration" :
116+ raise NotImplementedError (
117+ f"duration as { duration_type } not yet implemented"
118+ )
119+ assert json_type == "string"
120+ pyarrow .parquet .write_table (pa_table , where = dst )
121+
122+ def itertuples (
123+ self ,
124+ * ,
125+ geo_format : Literal ["wkb" , "wkt" ] = "wkt" ,
126+ duration_type : Literal ["int" , "timedelta" ] = "timedelta" ,
127+ json_type : Literal ["string" , "object" ] = "string" ,
128+ ) -> Iterable [tuple ]:
129+ """
130+ Yield each row as an unlabeled tuple.
131+
132+ Row-wise iteration of columnar data is slow, avoid if possible.
133+ """
134+ for row_dict in _iter_table (
135+ self .data ,
136+ self .schema ,
137+ geo_format = geo_format ,
138+ duration_type = duration_type ,
139+ json_type = json_type ,
140+ ):
141+ yield tuple (row_dict .values ())
142+
93143 def validate (self ):
94144 # TODO: Content-based validation for some datatypes (eg json, wkt, list) where logical domain is smaller than pyarrow type
95145 for bf_field , arrow_field in zip (self .schema .items , self .data .schema ):
@@ -101,11 +151,78 @@ def validate(self):
101151 )
102152
103153
104- def _get_managed_storage_type (dtype : bigframes .dtypes .Dtype ) -> pa .DataType :
105- if dtype in _MANAGED_STORAGE_TYPES_OVERRIDES .keys ():
106- return _MANAGED_STORAGE_TYPES_OVERRIDES [dtype ]
107- else :
108- return bigframes .dtypes .bigframes_dtype_to_arrow_dtype (dtype )
154+ # Sequential iterator, but could split into batches and leverage parallelism for speed
155+ def _iter_table (
156+ table : pa .Table ,
157+ schema : schemata .ArraySchema ,
158+ * ,
159+ geo_format : Literal ["wkb" , "wkt" ] = "wkt" ,
160+ duration_type : Literal ["int" , "timedelta" ] = "timedelta" ,
161+ json_type : Literal ["string" , "object" ] = "string" ,
162+ ) -> Generator [dict [str , Any ], None , None ]:
163+ """For when you feel like iterating row-wise over a column store. Don't expect speed."""
164+
165+ if geo_format != "wkt" :
166+ raise NotImplementedError (f"geo format { geo_format } not yet implemented" )
167+
168+ @functools .singledispatch
169+ def iter_array (
170+ array : pa .Array , dtype : bigframes .dtypes .Dtype
171+ ) -> Generator [Any , None , None ]:
172+ values = array .to_pylist ()
173+ if dtype == bigframes .dtypes .JSON_DTYPE :
174+ if json_type == "object" :
175+ yield from map (lambda x : json .loads (x ) if x is not None else x , values )
176+ else :
177+ yield from values
178+ elif dtype == bigframes .dtypes .TIMEDELTA_DTYPE :
179+ if duration_type == "int" :
180+ yield from map (
181+ lambda x : ((x .days * 3600 * 24 ) + x .seconds ) * 1_000_000
182+ + x .microseconds
183+ if x is not None
184+ else x ,
185+ values ,
186+ )
187+ else :
188+ yield from values
189+ else :
190+ yield from values
191+
192+ @iter_array .register
193+ def _ (
194+ array : pa .ListArray , dtype : bigframes .dtypes .Dtype
195+ ) -> Generator [Any , None , None ]:
196+ value_generator = iter_array (
197+ array .flatten (), bigframes .dtypes .get_array_inner_type (dtype )
198+ )
199+ for (start , end ) in itertools .pairwise (array .offsets ):
200+ arr_size = end .as_py () - start .as_py ()
201+ yield list (itertools .islice (value_generator , arr_size ))
202+
203+ @iter_array .register
204+ def _ (
205+ array : pa .StructArray , dtype : bigframes .dtypes .Dtype
206+ ) -> Generator [Any , None , None ]:
207+ # yield from each subarray
208+ sub_generators : dict [str , Generator [Any , None , None ]] = {}
209+ for field_name , dtype in bigframes .dtypes .get_struct_fields (dtype ).items ():
210+ sub_generators [field_name ] = iter_array (array .field (field_name ), dtype )
211+
212+ keys = list (sub_generators .keys ())
213+ for row_values in zip (* sub_generators .values ()):
214+ yield {key : value for key , value in zip (keys , row_values )}
215+
216+ for batch in table .to_batches ():
217+ sub_generators : dict [str , Generator [Any , None , None ]] = {}
218+ for field in schema .items :
219+ sub_generators [field .column ] = iter_array (
220+ batch .column (field .column ), field .dtype
221+ )
222+
223+ keys = list (sub_generators .keys ())
224+ for row_values in zip (* sub_generators .values ()):
225+ yield {key : value for key , value in zip (keys , row_values )}
109226
110227
111228def _adapt_pandas_series (
@@ -117,32 +234,63 @@ def _adapt_pandas_series(
117234 return pa .array (series , type = pa .string ()), bigframes .dtypes .GEO_DTYPE
118235 try :
119236 return _adapt_arrow_array (pa .array (series ))
120- except Exception as e :
237+ except pa . ArrowInvalid as e :
121238 if series .dtype == np .dtype ("O" ):
122239 try :
123- series = series .astype (bigframes .dtypes .GEO_DTYPE )
240+ return _adapt_pandas_series ( series .astype (bigframes .dtypes .GEO_DTYPE ) )
124241 except TypeError :
242+ # Prefer original error
125243 pass
126244 raise e
127245
128246
129247def _adapt_arrow_array (
130248 array : Union [pa .ChunkedArray , pa .Array ]
131249) -> tuple [Union [pa .ChunkedArray , pa .Array ], bigframes .dtypes .Dtype ]:
132- target_type = _arrow_type_replacements (array .type )
250+ target_type = _logical_type_replacements (array .type )
133251 if target_type != array .type :
134252 # TODO: Maybe warn if lossy conversion?
135253 array = array .cast (target_type )
136254 bf_type = bigframes .dtypes .arrow_dtype_to_bigframes_dtype (target_type )
255+
137256 storage_type = _get_managed_storage_type (bf_type )
138257 if storage_type != array .type :
139- raise TypeError (
140- f"Expected { bf_type } to use arrow { storage_type } , instead got { array .type } "
141- )
258+ array = array .cast (storage_type )
142259 return array , bf_type
143260
144261
145- def _arrow_type_replacements (type : pa .DataType ) -> pa .DataType :
262+ def _get_managed_storage_type (dtype : bigframes .dtypes .Dtype ) -> pa .DataType :
263+ if dtype in _MANAGED_STORAGE_TYPES_OVERRIDES .keys ():
264+ return _MANAGED_STORAGE_TYPES_OVERRIDES [dtype ]
265+ return _physical_type_replacements (
266+ bigframes .dtypes .bigframes_dtype_to_arrow_dtype (dtype )
267+ )
268+
269+
270+ def _recursive_map_types (
271+ f : Callable [[pa .DataType ], pa .DataType ]
272+ ) -> Callable [[pa .DataType ], pa .DataType ]:
273+ @functools .wraps (f )
274+ def recursive_f (type : pa .DataType ) -> pa .DataType :
275+ if pa .types .is_list (type ):
276+ new_field_t = recursive_f (type .value_type )
277+ if new_field_t != type .value_type :
278+ return pa .list_ (new_field_t )
279+ return type
280+ if pa .types .is_struct (type ):
281+ struct_type = cast (pa .StructType , type )
282+ new_fields : list [pa .Field ] = []
283+ for i in range (struct_type .num_fields ):
284+ field = struct_type .field (i )
285+ new_fields .append (field .with_type (recursive_f (field .type )))
286+ return pa .struct (new_fields )
287+ return f (type )
288+
289+ return recursive_f
290+
291+
292+ @_recursive_map_types
293+ def _logical_type_replacements (type : pa .DataType ) -> pa .DataType :
146294 if pa .types .is_timestamp (type ):
147295 # This is potentially lossy, but BigFrames doesn't support ns
148296 new_tz = "UTC" if (type .tz is not None ) else None
@@ -160,21 +308,24 @@ def _arrow_type_replacements(type: pa.DataType) -> pa.DataType:
160308 if pa .types .is_large_string (type ):
161309 # simple string type can handle the largest strings needed
162310 return pa .string ()
311+ if pa .types .is_dictionary (type ):
312+ return _logical_type_replacements (type .value_type )
163313 if pa .types .is_null (type ):
164314 # null as a type not allowed, default type is float64 for bigframes
165315 return pa .float64 ()
166- if pa .types .is_list (type ):
167- new_field_t = _arrow_type_replacements (type .value_type )
168- if new_field_t != type .value_type :
169- return pa .list_ (new_field_t )
170- return type
171- if pa .types .is_struct (type ):
172- struct_type = cast (pa .StructType , type )
173- new_fields : list [pa .Field ] = []
174- for i in range (struct_type .num_fields ):
175- field = struct_type .field (i )
176- field .with_type (_arrow_type_replacements (field .type ))
177- new_fields .append (field .with_type (_arrow_type_replacements (field .type )))
178- return pa .struct (new_fields )
179316 else :
180317 return type
318+
319+
320+ _ARROW_MANAGED_STORAGE_OVERRIDES = {
321+ bigframes .dtypes ._BIGFRAMES_TO_ARROW [bf_dtype ]: arrow_type
322+ for bf_dtype , arrow_type in _MANAGED_STORAGE_TYPES_OVERRIDES .items ()
323+ if bf_dtype in bigframes .dtypes ._BIGFRAMES_TO_ARROW
324+ }
325+
326+
327+ @_recursive_map_types
328+ def _physical_type_replacements (dtype : pa .DataType ) -> pa .DataType :
329+ if dtype in _ARROW_MANAGED_STORAGE_OVERRIDES :
330+ return _ARROW_MANAGED_STORAGE_OVERRIDES [dtype ]
331+ return dtype
0 commit comments