Skip to content

[CALCITE-4771] add TRY_CAST function (enabled in MSSQL library) #3136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions core/src/main/codegen/templates/Parser.jj
Original file line number Diff line number Diff line change
Expand Up @@ -6004,6 +6004,7 @@ SqlNode BuiltinFunctionCall() :
(
( <CAST> { f = SqlStdOperatorTable.CAST; }
| <SAFE_CAST> { f = SqlLibraryOperators.SAFE_CAST; }
| <TRY_CAST> { f = SqlLibraryOperators.TRY_CAST; }
)
{ s = span(); }
<LPAREN> AddExpression(args, ExprContext.ACCEPT_SUB_QUERY)
Expand Down Expand Up @@ -8353,6 +8354,7 @@ SqlPostfixOperator PostfixRowOperator() :
| < TRIM_ARRAY: "TRIM_ARRAY" >
| < TRUE: "TRUE" >
| < TRUNCATE: "TRUNCATE" >
| < TRY_CAST: "TRY_CAST" >
| < TUESDAY: "TUESDAY" >
| < TUMBLE: "TUMBLE" >
| < TYPE: "TYPE" >
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@
import static org.apache.calcite.sql.fun.SqlLibraryOperators.TO_BASE64;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.TRANSLATE3;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.TRUNC;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.TRY_CAST;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.UNIX_DATE;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.UNIX_MICROS;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.UNIX_MILLIS;
Expand Down Expand Up @@ -665,6 +666,7 @@ Builder populate2() {
map.put(COALESCE, new CoalesceImplementor());
map.put(CAST, new CastImplementor());
map.put(SAFE_CAST, new CastImplementor());
map.put(TRY_CAST, new CastImplementor());

map.put(REINTERPRET, new ReinterpretImplementor());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ public class SqlCastFunction extends SqlFunction {
//~ Constructors -----------------------------------------------------------

public SqlCastFunction() {
this(SqlKind.CAST);
this(SqlKind.CAST.toString(), SqlKind.CAST);
}

public SqlCastFunction(SqlKind kind) {
super(kind.toString(), kind, returnTypeInference(kind == SqlKind.SAFE_CAST),
public SqlCastFunction(String name, SqlKind kind) {
super(name, kind, returnTypeInference(kind == SqlKind.SAFE_CAST),
InferTypes.FIRST_KNOWN, null, SqlFunctionCategory.SYSTEM);
checkArgument(kind == SqlKind.CAST || kind == SqlKind.SAFE_CAST, kind);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1196,7 +1196,12 @@ static RelDataType deriveTypeSplit(SqlOperatorBinding operatorBinding,
* error. */
@LibraryOperator(libraries = {BIG_QUERY})
public static final SqlFunction SAFE_CAST =
new SqlCastFunction(SqlKind.SAFE_CAST);
new SqlCastFunction("SAFE_CAST", SqlKind.SAFE_CAST);

/** The "TRY_CAST(expr AS type)" function, equivalent to SAFE_CAST. */
@LibraryOperator(libraries = {MSSQL})
public static final SqlFunction TRY_CAST =
new SqlCastFunction("TRY_CAST", SqlKind.SAFE_CAST);

/** NULL-safe "&lt;=&gt;" equal operator used by MySQL, for example
* {@code 1<=>NULL}. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ private StandardConvertletTable() {
// Register convertlets for specific objects.
registerOp(SqlStdOperatorTable.CAST, this::convertCast);
registerOp(SqlLibraryOperators.SAFE_CAST, this::convertCast);
registerOp(SqlLibraryOperators.TRY_CAST, this::convertCast);
registerOp(SqlLibraryOperators.INFIX_CAST, this::convertCast);
registerOp(SqlStdOperatorTable.IS_DISTINCT_FROM,
(cx, call) -> convertIsDistinctFrom(cx, call, false));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ class SqlAdvisorTest extends SqlValidatorTestCase {
"KEYWORD(TRIM)",
"KEYWORD(TRUE)",
"KEYWORD(TRUNCATE)",
"KEYWORD(TRY_CAST)",
"KEYWORD(UNIQUE)",
"KEYWORD(UNKNOWN)",
"KEYWORD(UPPER)",
Expand Down
2 changes: 2 additions & 0 deletions site/_docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -1061,6 +1061,7 @@ TRIGGER_SCHEMA,
**TRIM_ARRAY**,
**TRUE**,
**TRUNCATE**,
**TRY_CAST**,
**TUESDAY**,
TUMBLE,
TYPE,
Expand Down Expand Up @@ -2758,6 +2759,7 @@ BigQuery's type system uses confusingly different names for types and functions:
| o p | TO_TIMESTAMP(string, format) | Converts *string* to a timestamp using the format *format*
| b o p | TRANSLATE(expr, fromString, toString) | Returns *expr* with all occurrences of each character in *fromString* replaced by its corresponding character in *toString*. Characters in *expr* that are not in *fromString* are not replaced
| b | TRUNC(numeric1 [, numeric2 ]) | Truncates *numeric1* to optionally *numeric2* (if not specified 0) places right to the decimal point
| q | TRY_CAST(value AS type) | Converts *value* to *type*, returning NULL if conversion fails
| b | UNIX_MICROS(timestamp) | Returns the number of microseconds since 1970-01-01 00:00:00
| b | UNIX_MILLIS(timestamp) | Returns the number of milliseconds since 1970-01-01 00:00:00
| b | UNIX_SECONDS(timestamp) | Returns the number of seconds since 1970-01-01 00:00:00
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -597,81 +597,96 @@ default SqlOperatorFixture forOracle(SqlConformance conformance) {
.with("fun", "oracle"));
}

/**
* Types for cast.
*/
enum CastType {
CAST("cast"),
SAFE_CAST("safe_cast"),
TRY_CAST("try_cast");

CastType(String name) {
this.name = name;
}

final String name;
}

default String getCastString(
String value,
String targetType,
boolean errorLoc,
boolean safe) {
CastType castType) {
if (errorLoc) {
value = "^" + value + "^";
}
String function = safe ? "safe_cast" : "cast";
String function = castType.name;
return function + "(" + value + " as " + targetType + ")";
}

default void checkCastToApproxOkay(String value, String targetType,
Object expected, boolean safe) {
checkScalarApprox(getCastString(value, targetType, false, safe),
getTargetType(targetType, safe), expected);
Object expected, CastType castType) {
checkScalarApprox(getCastString(value, targetType, false, castType),
getTargetType(targetType, castType), expected);
}

default void checkCastToStringOkay(String value, String targetType,
String expected, boolean safe) {
final String castString = getCastString(value, targetType, false, safe);
checkString(castString, expected, getTargetType(targetType, safe));
String expected, CastType castType) {
final String castString = getCastString(value, targetType, false, castType);
checkString(castString, expected, getTargetType(targetType, castType));
}

default void checkCastToScalarOkay(String value, String targetType,
String expected, boolean safe) {
final String castString = getCastString(value, targetType, false, safe);
checkScalarExact(castString, getTargetType(targetType, safe), expected);
String expected, CastType castType) {
final String castString = getCastString(value, targetType, false, castType);
checkScalarExact(castString, getTargetType(targetType, castType), expected);
}

default String getTargetType(String targetType, boolean safe) {
return safe ? targetType : targetType + NON_NULLABLE_SUFFIX;
default String getTargetType(String targetType, CastType castType) {
return castType == CastType.CAST ? targetType + NON_NULLABLE_SUFFIX : targetType;
}

default void checkCastToScalarOkay(String value, String targetType,
boolean safe) {
checkCastToScalarOkay(value, targetType, value, safe);
CastType castType) {
checkCastToScalarOkay(value, targetType, value, castType);
}

default void checkCastFails(String value, String targetType,
String expectedError, boolean runtime, boolean safe) {
final String castString = getCastString(value, targetType, !runtime, safe);
String expectedError, boolean runtime, CastType castType) {
final String castString = getCastString(value, targetType, !runtime, castType);
checkFails(castString, expectedError, runtime);
}

default void checkCastToString(String value, @Nullable String type,
@Nullable String expected, boolean safe) {
@Nullable String expected, CastType castType) {
String spaces = " ";
if (expected == null) {
expected = value.trim();
}
int len = expected.length();
if (type != null) {
value = getCastString(value, type, false, safe);
value = getCastString(value, type, false, castType);
}

// currently no exception thrown for truncation
if (Bug.DT239_FIXED) {
checkCastFails(value,
"VARCHAR(" + (len - 1) + ")", STRING_TRUNC_MESSAGE,
true, safe);
true, castType);
}

checkCastToStringOkay(value, "VARCHAR(" + len + ")", expected, safe);
checkCastToStringOkay(value, "VARCHAR(" + (len + 5) + ")", expected, safe);
checkCastToStringOkay(value, "VARCHAR(" + len + ")", expected, castType);
checkCastToStringOkay(value, "VARCHAR(" + (len + 5) + ")", expected, castType);

// currently no exception thrown for truncation
if (Bug.DT239_FIXED) {
checkCastFails(value,
"CHAR(" + (len - 1) + ")", STRING_TRUNC_MESSAGE,
true, safe);
true, castType);
}

checkCastToStringOkay(value, "CHAR(" + len + ")", expected, safe);
checkCastToStringOkay(value, "CHAR(" + len + ")", expected, castType);
checkCastToStringOkay(value, "CHAR(" + (len + 5) + ")",
expected + spaces, safe);
expected + spaces, castType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ private SqlOperatorFixtures() {
}

/** Returns a fixture that converts each CAST test into a test for
* SAFE_CAST. */
static SqlOperatorFixture safeCastWrapper(SqlOperatorFixture fixture) {
* SAFE_CAST or TRY_CAST. */
static SqlOperatorFixture safeCastWrapper(SqlOperatorFixture fixture, String functionName) {
return (SqlOperatorFixture) Proxy.newProxyInstance(
SqlOperatorTest.class.getClassLoader(),
new Class[]{SqlOperatorFixture.class},
new SqlOperatorFixtureInvocationHandler(fixture));
new SqlOperatorFixtureInvocationHandler(fixture, functionName));
}

/** A helper for {@link #safeCastWrapper(SqlOperatorFixture)} that provides
/** A helper for {@link #safeCastWrapper(SqlOperatorFixture, String)} that provides
* alternative implementations of methods in {@link SqlOperatorFixture}.
*
* <p>Must be public, so that its methods can be seen via reflection. */
Expand All @@ -49,29 +49,31 @@ public static class SqlOperatorFixtureInvocationHandler
static final Pattern NOT_NULL_PATTERN = Pattern.compile(" NOT NULL");

final SqlOperatorFixture f;
final String functionName;

SqlOperatorFixtureInvocationHandler(SqlOperatorFixture f) {
SqlOperatorFixtureInvocationHandler(SqlOperatorFixture f, String functionName) {
this.f = f;
this.functionName = functionName;
}

@Override protected Object getTarget() {
return f;
}

String addSafe(String sql) {
return CAST_PATTERN.matcher(sql).replaceAll("SAFE_CAST(");
return CAST_PATTERN.matcher(sql).replaceAll(functionName + "(");
}

String removeNotNull(String type) {
return NOT_NULL_PATTERN.matcher(type).replaceAll("");
}

/** Proxy for
* {@link SqlOperatorFixture#checkCastToString(String, String, String, boolean)}. */
* {@link SqlOperatorFixture#checkCastToString(String, String, String, SqlOperatorFixture.CastType)}. */
public void checkCastToString(String value, @Nullable String type,
@Nullable String expected, boolean safe) {
@Nullable String expected, SqlOperatorFixture.CastType castType) {
f.checkCastToString(addSafe(value),
type == null ? null : removeNotNull(type), expected, safe);
type == null ? null : removeNotNull(type), expected, castType);
}

/** Proxy for {@link SqlOperatorFixture#checkBoolean(String, Boolean)}. */
Expand Down
Loading