perf: unwrap cast for comparing ints =/!= strings #15110
Thank you so much @alan910127
I think this is looking quite close. I left a suggestion about how to do the conversion
I also think this needs a few more tests as suggested below.
FYI @jayzhan211 and @ion-elgreco
negative / null cases
Can you also test predicates with null literals and the other negative cases below?
-- compare with Null literal (ScalarValue::Utf8(None))
cast(c1, Utf8) = NULL
-- compare with strings that are not a number
cast(c1, Utf8) = 'not a number'
-- compare with strings that are numbers but do not fit in the range
-- where c1 is an Int8 (max value 127), try comparing to a number
-- that is greater than 127
cast(c1, Utf8) = '1000'
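To make the expected outcome of these cases concrete, here is a minimal sketch using only the Rust standard library. It uses `str::parse` as a stand-in for the Arrow cast kernel (an assumption; the two may not agree on every edge case), and the helper name `literal_to_i8` is hypothetical. In each `None` case the cast would have to stay in place rather than be unwrapped.

```rust
/// Hypothetical helper: try to convert a string literal into an Int8 value.
/// `None` as input models a NULL literal (ScalarValue::Utf8(None)).
/// Returns `None` when the literal is NULL, not a number, or out of range.
fn literal_to_i8(lit: Option<&str>) -> Option<i8> {
    lit?.parse::<i8>().ok()
}

fn main() {
    // NULL literal, non-numeric string, out-of-range number, in-range number.
    for lit in [None, Some("not a number"), Some("1000"), Some("123")] {
        println!("{:?} -> {:?}", lit, literal_to_i8(lit));
    }
}
```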
"End to end" (sql level)
Can we please add at least one "end to end" style test in sqllogictest style (EXPLAIN ....)
Here are the instructions: https://github.com/apache/datafusion/tree/main/datafusion/sqllogictest
Ideally you should be able to extend one of the existing test files in https://github.com/apache/datafusion/tree/main/datafusion/sqllogictest/test_files
    | ScalarValue::Utf8View(Some(ref str))
    | ScalarValue::LargeUtf8(Some(ref str)),
) => match target_type {
    DataType::Int8 => str.parse::<i8>().ok().map(ScalarValue::from),
By calling ok() here I think this converts a predicate like cast(col, utf8) = 'foo' into col = 'null', which is not quite the same thing.
Also, str.parse() likely has subtly different semantics than the arrow cast kernels.
Thus, can you please make this function:
- Return None if the conversion fails (following the convention of other functions in this module like try_cast_dictionary)
- Use the cast kernel to parse the values?
The cast kernel can be called by:
- Converting ScalarValue to an array with ScalarValue::to_array
- Calling cast: https://docs.rs/arrow/latest/arrow/array/cast/index.html
- Calling ScalarValue::try_from_array
I tried to use your suggested casting implementation, could you help me check if I'm implementing it correctly?
Also 2 sqllogictest test cases are added (optimized/number not in range)
Thank you so much @alan910127 -- this is great
While reviewing this PR I had a few more tests I wanted to suggest, but instead of doing another round of PR reviews I figured I would just push the commits directly to save time.
I also verified that DuckDB does this type of simplification too (note the CAST('123' AS STRUCT(x INTEGER)), which is casting the string to int):
D create table t (x int);
D insert into t values (1);
D explain select * from t where t = '123';
┌─────────────────────────────┐
│┌───────────────────────────┐│
││ Physical Plan ││
│└───────────────────────────┘│
└─────────────────────────────┘
┌───────────────────────────┐
│ FILTER │
│ ──────────────────── │
│ (struct_pack(x) = CAST( │
│ '123' AS STRUCT(x INTEGER)│
│ )) │
│ │
│ ~1 Rows │
└─────────────┬─────────────┘
┌─────────────┴─────────────┐
│ SEQ_SCAN │
│ ──────────────────── │
│ Table: t │
│ Type: Sequential Scan │
│ Projections: x │
│ │
│ ~1 Rows │
└───────────────────────────┘
Strangely they don't seem to care about errors:
D select * from t where t = '999999999999999999999999999999999999';
Conversion Error:
Type VARCHAR with value '999999999999999999999999999999999999' can't be cast to the destination type STRUCT
LINE 1: select * from t where t = '999999999999999999999999999999999999';
D select * from t where t = 'foo';
Conversion Error:
Type VARCHAR with value 'foo' can't be cast to the destination type STRUCT
LINE 1: select * from t where t = 'foo';
^
D select * from t;
┌───────┐
│ x │
│ int32 │
├───────┤
│ 1 │
└───────┘
    target_type: &DataType,
    op: Operator,
) -> Option<ScalarValue> {
    macro_rules! cast_or_else_return_none {
Since ScalarValue is already dynamic, I don't think we need a macro to run on each type.
I played around a little locally and this is what I came up with. What do you think?
// Imports assumed to come from the surrounding module.
use arrow::datatypes::DataType;
use datafusion_common::ScalarValue;
use datafusion_expr::Operator;

/// Try to move a cast from a literal to the other side of a `=` / `!=` operator
///
/// Specifically, rewrites
/// ```sql
/// cast(col) <op> <literal>
/// ```
///
/// To
///
/// ```sql
/// col <op> cast(<literal>)
/// col <op> <cast_literal>
/// ```
fn cast_literal_to_type_with_op(
    lit_value: &ScalarValue,
    target_type: &DataType,
    op: Operator,
) -> Option<ScalarValue> {
    // Only `=` / `!=` against a non-null string literal is handled.
    let (
        Operator::Eq | Operator::NotEq,
        ScalarValue::Utf8(Some(_))
        | ScalarValue::Utf8View(Some(_))
        | ScalarValue::LargeUtf8(Some(_)),
    ) = (op, lit_value)
    else {
        return None;
    };

    // Only try for integer types
    // (TODO: can we do this for other types like timestamps?)
    use DataType::*;
    if matches!(
        target_type,
        Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
    ) {
        // `safe: false` makes the cast kernel return an error (rather than
        // a null) when the string cannot be converted.
        let opts = arrow::compute::CastOptions {
            safe: false,
            format_options: Default::default(),
        };

        let array = ScalarValue::to_array(lit_value).ok()?;
        let casted =
            arrow::compute::cast_with_options(&array, target_type, &opts).ok()?;
        ScalarValue::try_from_array(&casted, 0).ok()
    } else {
        None
    }
}
instead of doing another round of PR reviews I figured I would just push the commits directly to save time.
Thank you! I’m just curious—is storing a parquet file in slt tests common? The tests I’ve worked with so far use tables created in the same test file.
Strangely they don't seem to care about errors:
Interesting! This makes me wonder if they treat this behavior as an "optimization" or if it's simply an expected behavior to cast the literal to the column's type.
I played around a little locally and this is what I came up with. What do you think?
Oh, this looks much better! However, I’d prefer leaving the outer match as is, maybe there’s potential for another optimization in the future?
if they treat this behavior as an "optimization" or if it's simply an expected behavior to cast the literal to the column's type
It looks like my guess might be correct—they always cast the literal to the column's type, regardless of the operator.
D create table t as values (1), (2), (3);
D select * from t where col0 < '10';
┌───────┐
│ col0 │
│ int32 │
├───────┤
│ 1 │
│ 2 │
│ 3 │
└───────┘
@findepi I wonder if you have some time to double check the correctness of this optimization (distributing a cast)
@@ -1758,7 +1758,7 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
             // try_cast/cast(expr as data_type) op literal
             Expr::BinaryExpr(BinaryExpr { left, op, right })
                 if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
-                    info, &left, &right,
+                    info, &left, &right, op,
Add op between the left and right args -- this helps understand the role of the two &Expr args to the function (or maybe even pass the whole &BinaryExpr into the function).
    ScalarValue::Utf8(Some(_))
    | ScalarValue::Utf8View(Some(_))
    | ScalarValue::LargeUtf8(Some(_)),
) => match target_type {
Does target_type represent source type of the cast?
Unwrapping a cast is not an easy feat. The current implementation isn't exactly correct, as it conflates different strings which cast back to the same number:
CREATE OR REPLACE TABLE t AS SELECT arrow_cast('123', 'Int64') a;
-- correctly finds the row
SELECT * FROM t WHERE cast(a AS string) = '123';
-- incorrectly also finds the row
SELECT * FROM t WHERE cast(a AS string) = '0123';
@findepi so the test cases in |
I think @findepi is saying that the mapping proposed in the PR is more subtle casting an integer
It would not match the string
However, as written this PR will also match
So in other words, we should not do the conversion from
Maybe we can add a check that the string would be the same when we round-tripped it
As in only do the rewrite if the sequence
results in the exact same string as went in
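A minimal sketch of the round-trip check suggested above, using only the Rust standard library. It assumes `str::parse` and `to_string` as stand-ins for the Arrow string-to-int and int-to-string cast kernels (the real kernels may differ in edge cases), and `round_trip_i64` is a hypothetical name, not a function in DataFusion.

```rust
/// Hypothetical helper: return the integer only if `lit` is exactly the
/// string that casting that integer back to a string would produce.
/// When this returns `None`, the cast should not be unwrapped.
fn round_trip_i64(lit: &str) -> Option<i64> {
    // string -> int; fails for non-numeric or out-of-range input
    let value: i64 = lit.parse().ok()?;
    // int -> string must reproduce the literal byte for byte
    (value.to_string() == lit).then_some(value)
}

fn main() {
    for lit in ["123", "0123", "+123", "-0", "abc"] {
        println!("{:?} -> {:?}", lit, round_trip_i64(lit));
    }
}
```

With this check, '123' can be rewritten to an integer comparison, while '0123' is rejected because it parses to 123 but does not print back as '0123'.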
(thank you for checking this @findepi -- 🙏)
neat
match (op, lit_value) {
    (
        Operator::Eq | Operator::NotEq,
        ScalarValue::Utf8(Some(_))
        | ScalarValue::Utf8View(Some(_))
        | ScalarValue::LargeUtf8(Some(_)),
    ) => match target_type {
        DataType::Int8 => cast_or_else_return_none!(lit_value, DataType::Int8),
        DataType::Int16 => cast_or_else_return_none!(lit_value, DataType::Int16),
        DataType::Int32 => cast_or_else_return_none!(lit_value, DataType::Int32),
        DataType::Int64 => cast_or_else_return_none!(lit_value, DataType::Int64),
        DataType::UInt8 => cast_or_else_return_none!(lit_value, DataType::UInt8),
        DataType::UInt16 => cast_or_else_return_none!(lit_value, DataType::UInt16),
        DataType::UInt32 => cast_or_else_return_none!(lit_value, DataType::UInt32),
        DataType::UInt64 => cast_or_else_return_none!(lit_value, DataType::UInt64),
        _ => None,
    },
    _ => None,
As it's written, and as the function is named, this looks like a generic approach that will work for many different source/target type pairs; we just need to add logic.
However, the exact checks necessary for soundness will likely differ.
For casts between exact numbers (ints and decimals) and string types, the #15110 (comment) looks like a sound plan.
Plus, we could also simplify impossible target literals like this
cast(an_int as string) = 'abc'
or
cast(an_int as string) = '007'
to
if(an_int is null, null, false)
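To illustrate why the simplified form keeps the NULL check instead of collapsing to plain `false`, here is a small sketch in Rust that models SQL's three-valued logic with `Option<bool>` (`None` standing for NULL). Both function names are hypothetical, and `to_string` is assumed to behave like the int-to-string cast.

```rust
/// Evaluates `cast(an_int as string) = lit` with NULL modelled as `None`.
fn cast_eq_literal(an_int: Option<i64>, lit: &str) -> Option<bool> {
    an_int.map(|v| v.to_string() == lit)
}

/// Evaluates the proposed replacement `if(an_int is null, null, false)`.
fn simplified(an_int: Option<i64>) -> Option<bool> {
    an_int.map(|_| false)
}

fn main() {
    // 'abc' and '007' can never be the string form of an integer, so the
    // two expressions agree for every input, including NULL.
    for x in [None, Some(7), Some(0), Some(-7)] {
        println!(
            "{:?}: 'abc' -> {:?}, '007' -> {:?}, simplified -> {:?}",
            x,
            cast_eq_literal(x, "abc"),
            cast_eq_literal(x, "007"),
            simplified(x)
        );
    }
}
```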
For casts between numbers and numbers, we likely need to recognize maximum addressable values and cast imprecision, especially when we want to handle the <, <=, >, >= operators.
https://github.com/trinodb/trino/blob/a870656f406b0f76e97c740659a58554d472994d/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java#L199-L354 might serve as a good reference. Note how it handles NaN, max values, lossy round-tripping, sometimes converting < to <= as necessary.
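Two of these pitfalls can be seen with plain Rust numeric casts. The sketch below is only an illustration under that assumption (Rust `as` casts standing in for an int-to-double SQL cast); it is not taken from the Trino rule, and the function names are hypothetical.

```rust
/// Evaluates `cast(i as double) < lit` the way the un-rewritten predicate would.
fn lt_via_cast(i: i64, lit: f64) -> bool {
    (i as f64) < lit
}

/// True if two distinct integers become indistinguishable after the cast.
fn casts_collide(a: i64, b: i64) -> bool {
    a != b && (a as f64) == (b as f64)
}

fn main() {
    // Lossy round trip: 2^53 and 2^53 + 1 map to the same double, so an
    // equality on the cast value cannot be turned into a single `i = <int>`.
    let two_pow_53: i64 = 1 << 53;
    println!("collide: {}", casts_collide(two_pow_53, two_pow_53 + 1));

    // Operator change: `cast(i as double) < 2.5` matches the same rows as
    // `i <= 2`, not `i < 2`.
    for i in 0..=4i64 {
        println!("i = {}: {} vs {}", i, lt_via_cast(i, 2.5), i <= 2);
    }
}
```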
For e.g. a cast from timestamp to date, the logic will likely be different again:
https://github.com/trinodb/trino/blob/a870656f406b0f76e97c740659a58554d472994d/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java#L357-L389.
With all this in mind, I'd suggest structuring this code so it's clear it addresses only the exact number and string case, by checking expr and literal types and delegating to a function that has "int" and "utf8" or "string" in the name.
@alamb I found that this approach is different from how postgres and duckdb handle this situation:
postgres
db=# create table t as select 123 a;
SELECT 1
db=# select * from t where a = '0123';
a
-----
123
(1 row)
db=# select * from t where cast(a AS text) = '0123';
a
---
(0 rows)
duckdb
D create table t as select 123 a;
D select * from t where a = '0123';
┌───────┐
│ a │
│ int32 │
├───────┤
│ 123 │
└───────┘
D select * from t where cast(a AS string) = '0123';
┌────────┐
│ a │
│ int32 │
├────────┤
│ 0 rows │
└────────┘
I'm not sure if it's only possible to handle it when the
That may well be the right thing to do 🤔 It is funny that @scsmithr filed a ticket this morning with a similar symptom: 🤔
This reverts commit 808d6ab.
This cast unwrapping is valuable orthogonally to what coercions get applied by the frontend. And it looks feasible -- the plan in #15110 (comment) IMO should work (for integer and string type pairs).
For context, I was running some implicit cast queries on a few different systems to compare outputs, and hit the aforementioned issue with datafusion. "Unexpected" might be a better word than "incorrect", but datafusion is the outlier here.
My take on this: I don't believe manipulating casts after-the-fact is the right approach. Given the query
Without a proper "unknown" type, I think it should be sufficient to rank implicitly casting to utf8 last compared to any other implicit cast.
@findepi sorry I didn't notice your comment and I just pushed a new version with the unwrapping logic deleted. So you think the two optimizations should coexist?
But we have no way to tell if the cast is set by the user or not in the unwrap cast function. Doesn't this matter?
I agree @scsmithr -- for the case in #15161 I agree that changing the coercion rules as you suggest is a better plan than trying to fix it up afterwards. However, I think the optimization in this PR is valuable for cases when the user has ended up with the casts some other way and DataFusion is now optimizing the expression.
That is my understanding and also what I think should happen.
I don't think the source of the cast matters (for this particular optimization).
Ah, yes. Since the string -> int -> string round trip is checked and we're only checking equality, they'll be the same. But I'm not sure if this unwrapping would happen after the string literal coercion.
Got it, should I revert this PR to the unwrap cast optimization or directly combine both the optimization and the literal coercion in this PR?
Maybe we can figure out if it is possible to parse to the ideal type for a scalar, so '123' can be cast to i64 rather than str and '1.2' can be cast to float rather than str.
I tried a bit, and I think there is still some work to be done in how we parse the SQL: #15202.
Actually, I'm quite curious: is a string literal really an issue? If we want a string, we can write the query with quotes.
To summarize the feedback so far:
I think this PR needs one more test as suggested by @findepi #15110 (comment)
-- should not remove cast on string
explain select a from t where cast(a as string) = '0100';
Here is a suggestion on making that pass: #15110 (comment)
Maybe we can add @findepi's analysis in as comments too
I recommend that we file separate tickets for future work (such as supporting other types).
Does this make sense @alan910127 ?
@@ -330,12 +330,12 @@ async fn test_create_physical_expr_coercion() {
     create_expr_test(lit(1i32).eq(col("id")), "CAST(1 AS Utf8) = id@0");
     // compare int col to string literal `i = '202410'`
     // Note this casts the column (not the field)
-    create_expr_test(col("i").eq(lit("202410")), "CAST(i@1 AS Utf8) = 202410");
-    create_expr_test(lit("202410").eq(col("i")), "202410 = CAST(i@1 AS Utf8)");
+    create_expr_test(col("i").eq(lit("202410")), "i@1 = 202410");
🎉
Which issue does this PR close?
Rationale for this change
What changes are included in this PR?
unwrap_cast_in_comparison_for_binary and is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary to use the new function.
Are these changes tested?
The following test cases are added:
CAST(c1, UTF8) < "123"
CAST(c1, UTF8) = "123"
CAST(c1, UTF8) != "123"
CAST(c6, UTF8) = "123"
CAST(c6, UTF8) != "123"
Are there any user-facing changes?
No.