Skip to content

Commit 3b5d2a5

Browse files
bump dsl version, flatten, rebase
1 parent 5cedf53 commit 3b5d2a5

File tree

8 files changed

+351
-2
lines changed

8 files changed

+351
-2
lines changed

py-polars/docs/source/reference/io.rst

+9
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,12 @@ Configuration for cloud credential provisioning.
154154
CredentialProviderAWS
155155
CredentialProviderAzure
156156
CredentialProviderGCP
157+
158+
Scan Cast Options
159+
~~~~~~~~~~~~~~~~~
160+
Configuration for type-casting during scans.
161+
162+
.. autosummary::
163+
:toctree: api/
164+
165+
ScanCastOptions

py-polars/polars/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@
175175
PartitionByKey,
176176
PartitionMaxSize,
177177
PartitionParted,
178+
ScanCastOptions,
178179
defer,
179180
read_avro,
180181
read_clipboard,
@@ -284,6 +285,7 @@
284285
"PartitionByKey",
285286
"PartitionMaxSize",
286287
"PartitionParted",
288+
"ScanCastOptions",
287289
"read_avro",
288290
"read_clipboard",
289291
"read_csv",

py-polars/polars/io/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Functions for reading data."""
22

33
from polars.io.avro import read_avro
4+
from polars.io.cast_options import ScanCastOptions
45
from polars.io.clipboard import read_clipboard
56
from polars.io.csv import read_csv, read_csv_batched, scan_csv
67
from polars.io.database import read_database, read_database_uri
@@ -35,6 +36,7 @@
3536
"KeyedPartition",
3637
"BasePartitionContext",
3738
"KeyedPartitionContext",
39+
"ScanCastOptions",
3840
"read_avro",
3941
"read_clipboard",
4042
"read_csv",

py-polars/polars/io/cast_options.py

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Literal
4+
5+
from polars._utils.unstable import issue_unstable_warning
6+
7+
if TYPE_CHECKING:
8+
from collections.abc import Collection
9+
10+
from typing_extensions import TypeAlias
11+
12+
13+
# Permitted (non-"forbid") cast behaviors for float source dtypes.
FloatCastOption: TypeAlias = Literal["upcast", "downcast"]
# Permitted (non-"forbid") cast behaviors for datetime source dtypes.
DatetimeCastOption: TypeAlias = Literal["nanosecond-downcast", "convert-timezone"]
15+
16+
17+
class ScanCastOptions:
    """Options for type-casting when scanning files."""

    def __init__(
        self,
        *,
        integer_cast: Literal["upcast", "forbid"] = "forbid",
        float_cast: Literal["forbid"]
        | FloatCastOption
        | Collection[FloatCastOption] = "forbid",
        datetime_cast: Literal["forbid"]
        | DatetimeCastOption
        | Collection[DatetimeCastOption] = "forbid",
        missing_struct_fields: Literal["insert", "raise"] = "raise",
        extra_struct_fields: Literal["ignore", "raise"] = "raise",
        _internal_call: bool = False,
    ) -> None:
        """
        Configuration for type-casting of columns when reading files.

        Useful when scanning datasets whose schemas have been modified over
        time. An instance of this object is generally passed to a supported
        `scan_*` function via the `cast_options` parameter.

        .. warning::
            This functionality is considered **unstable**. It may be changed
            at any point without it being considered a breaking change.

        Parameters
        ----------
        integer_cast
            Casting behavior for integer source types:

            * `upcast`: Allow lossless casting to wider integer types.
            * `forbid`: Raises an error if dtypes do not match.

        float_cast
            Casting behavior for float source types:

            * `upcast`: Allow casting to higher precision float types.
            * `downcast`: Allow casting to lower precision float types.
            * `forbid`: Raises an error if dtypes do not match.

        datetime_cast
            Casting behavior for datetime source types:

            * `nanosecond-downcast`: Allow nanosecond precision datetime to be \
            downcast to any lower precision. This has a similar effect to \
            PyArrow's `coerce_int96_timestamp_unit`.
            * `convert-timezone`: Allow casting to a different timezone.
            * `forbid`: Raises an error if dtypes do not match.

        missing_struct_fields
            Behavior when struct fields defined in the schema are absent
            from the data:

            * `insert`: Inserts the missing fields.
            * `raise`: Raises an error.

        extra_struct_fields
            Behavior when struct fields outside of the defined schema are
            encountered in the data:

            * `ignore`: Silently ignores.
            * `raise`: Raises an error.

        """
        # `_internal_call` suppresses the unstable-feature warning when this
        # object is constructed internally (see `_default` below).
        if not _internal_call:
            issue_unstable_warning("ScanCastOptions is considered unstable.")

        self.integer_cast = integer_cast
        self.float_cast = float_cast
        self.datetime_cast = datetime_cast
        self.missing_struct_fields = missing_struct_fields
        self.extra_struct_fields = extra_struct_fields

    # Called from the Rust side; exists so that constructing the default
    # options does not accidentally print unstable-feature messages.
    @staticmethod
    def _default() -> ScanCastOptions:
        return ScanCastOptions(_internal_call=True)

py-polars/polars/io/parquet/functions.py

+16
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
from polars import DataFrame, DataType, LazyFrame
3737
from polars._typing import FileSource, ParallelStrategy, SchemaDict
38+
from polars.io.cast_options import ScanCastOptions
3839
from polars.io.cloud import CredentialProviderFunction
3940
from polars.io.cloud.credential_provider._builder import CredentialProviderBuilder
4041

@@ -384,6 +385,7 @@ def scan_parquet(
384385
retries: int = 2,
385386
include_file_paths: str | None = None,
386387
allow_missing_columns: bool = False,
388+
cast_options: ScanCastOptions | None = None,
387389
) -> LazyFrame:
388390
"""
389391
Lazily read from a local or cloud-hosted parquet file (or files).
@@ -491,6 +493,13 @@ def scan_parquet(
491493
raise an error. However, if `allow_missing_columns` is set to
492494
`True`, a full-NULL column is returned instead of erroring for the files
493495
that do not contain the column.
496+
cast_options
497+
Configuration for column type-casting during scans. Useful for datasets
498+
containing files that have differing schemas.
499+
500+
.. warning::
501+
This functionality is considered **unstable**. It may be changed
502+
at any point without it being considered a breaking change.
494503
495504
See Also
496505
--------
@@ -522,6 +531,10 @@ def scan_parquet(
522531
msg = "the `hive_schema` parameter of `scan_parquet` is considered unstable."
523532
issue_unstable_warning(msg)
524533

534+
if cast_options is not None:
535+
msg = "The `cast_options` parameter of `scan_parquet` is considered unstable."
536+
issue_unstable_warning(msg)
537+
525538
if isinstance(source, (str, Path)):
526539
source = normalize_filepath(source, check_not_directory=False)
527540
elif is_path_or_str_sequence(source):
@@ -553,6 +566,7 @@ def scan_parquet(
553566
glob=glob,
554567
include_file_paths=include_file_paths,
555568
allow_missing_columns=allow_missing_columns,
569+
cast_options=cast_options,
556570
)
557571

558572

@@ -577,6 +591,7 @@ def _scan_parquet_impl(
577591
retries: int = 2,
578592
include_file_paths: str | None = None,
579593
allow_missing_columns: bool = False,
594+
cast_options: ScanCastOptions | None = None,
580595
) -> LazyFrame:
581596
if isinstance(source, list):
582597
sources = source
@@ -610,5 +625,6 @@ def _scan_parquet_impl(
610625
glob=glob,
611626
include_file_paths=include_file_paths,
612627
allow_missing_columns=allow_missing_columns,
628+
cast_options=cast_options,
613629
)
614630
return wrap_ldf(pylf)

py-polars/tests/unit/io/test_lazy_parquet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,7 @@ def test_parquet_schema_arg(
805805

806806
with pytest.raises(
807807
pl.exceptions.SchemaError,
808-
match="data type mismatch for column b: expected: i8, found: i64",
808+
match="data type mismatch for column b: incoming: Int64 != target: Int8",
809809
):
810810
lf.collect(engine="streaming" if streaming else "in-memory")
811811

py-polars/tests/unit/io/test_multiscan.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,10 @@ def test_schema_mismatch_type_mismatch(
338338
if scan is pl.scan_ndjson
339339
else pytest.raises(
340340
pl.exceptions.SchemaError,
341-
match="data type mismatch for column xyz_col: expected: i64, found: str",
341+
match=(
342+
"data type mismatch for column xyz_col: "
343+
"incoming: String != target: Int64"
344+
),
342345
)
343346
)
344347

0 commit comments

Comments
 (0)