Skip to content

Commit

Permalink
Merge branch '8.0-maint' into backport/pr1863_to_8.0
Browse files Browse the repository at this point in the history
  • Loading branch information
lauxjpn authored Mar 16, 2024
2 parents e5c221a + 73aab49 commit 80fb595
Show file tree
Hide file tree
Showing 5 changed files with 377 additions and 9 deletions.
152 changes: 152 additions & 0 deletions src/EFCore.MySql/Extensions/MySqlDbFunctionsExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,158 @@ namespace Microsoft.EntityFrameworkCore
/// </summary>
public static class MySqlDbFunctionsExtensions
{
#region ConvertTimeZone

/// <summary>
/// Converts the `DateTime` value <paramref name="dateTime"/> from the time zone given by <paramref name="fromTimeZone"/> to the time zone given by <paramref name="toTimeZone"/> and returns the resulting value.
/// Corresponds to `CONVERT_TZ(dateTime, fromTimeZone, toTimeZone)`.
/// </summary>
/// <param name="_">The DbFunctions instance.</param>
/// <param name="dateTime">The `DateTime` value to convert.</param>
/// <param name="fromTimeZone">The time zone to convert from.</param>
/// <param name="toTimeZone">The time zone to convert to.</param>
/// <returns>The converted value.</returns>
public static DateTime? ConvertTimeZone(
[CanBeNull] this DbFunctions _,
DateTime dateTime,
string fromTimeZone,
string toTimeZone)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(ConvertTimeZone)));

/// <summary>
/// Converts the `DateOnly` value <paramref name="dateOnly"/> from the time zone given by <paramref name="fromTimeZone"/> to the time zone given by <paramref name="toTimeZone"/> and returns the resulting value.
/// Corresponds to `CONVERT_TZ(dateTime, fromTimeZone, toTimeZone)`..
/// </summary>
/// <param name="_">The DbFunctions instance.</param>
/// <param name="dateOnly">The `DateOnly` value to convert.</param>
/// <param name="fromTimeZone">The time zone to convert from.</param>
/// <param name="toTimeZone">The time zone to convert to.</param>
/// <returns>The converted value.</returns>
public static DateOnly? ConvertTimeZone(
[CanBeNull] this DbFunctions _,
DateOnly dateOnly,
string fromTimeZone,
string toTimeZone)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(ConvertTimeZone)));

/// <summary>
/// Converts the `DateTime?` value <paramref name="dateTime"/> from the time zone given by <paramref name="fromTimeZone"/> to the time zone given by <paramref name="toTimeZone"/> and returns the resulting value.
/// Corresponds to `CONVERT_TZ(dateTime, fromTimeZone, toTimeZone)`.
/// </summary>
/// <param name="_">The DbFunctions instance.</param>
/// <param name="dateTime">The `DateTime?` value to convert.</param>
/// <param name="fromTimeZone">The time zone to convert from.</param>
/// <param name="toTimeZone">The time zone to convert to.</param>
/// <returns>The converted value.</returns>
public static DateTime? ConvertTimeZone(
[CanBeNull] this DbFunctions _,
DateTime? dateTime,
string fromTimeZone,
string toTimeZone)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(ConvertTimeZone)));

/// <summary>
/// Converts the `DateOnly?` value <paramref name="dateOnly"/> from the time zone given by <paramref name="fromTimeZone"/> to the time zone given by <paramref name="toTimeZone"/> and returns the resulting value.
/// Corresponds to `CONVERT_TZ(dateTime, fromTimeZone, toTimeZone)`..
/// </summary>
/// <param name="_">The DbFunctions instance.</param>
/// <param name="dateOnly">The `DateOnly?` value to convert.</param>
/// <param name="fromTimeZone">The time zone to convert from.</param>
/// <param name="toTimeZone">The time zone to convert to.</param>
/// <returns>The converted value.</returns>
public static DateOnly? ConvertTimeZone(
[CanBeNull] this DbFunctions _,
DateOnly? dateOnly,
string fromTimeZone,
string toTimeZone)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(ConvertTimeZone)));

/// <summary>
/// Converts the `DateTime` value <paramref name="dateTime"/> from `@@session.time_zone` to the time zone given by <paramref name="toTimeZone"/> and returns the resulting value.
/// Corresponds to `CONVERT_TZ(dateTime, @@session.time_zone, toTimeZone)`.
/// </summary>
/// <param name="_">The DbFunctions instance.</param>
/// <param name="dateTime">The `DateTime` value to convert.</param>
/// <param name="toTimeZone">The time zone to convert to.</param>
/// <returns>The converted value.</returns>
public static DateTime? ConvertTimeZone(
[CanBeNull] this DbFunctions _,
DateTime dateTime,
string toTimeZone)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(ConvertTimeZone)));

/// <summary>
/// Converts the `DateTimeOffset` value <paramref name="dateTimeOffset"/> from `+00:00`/UTC to the time zone given by <paramref name="toTimeZone"/> and returns the resulting value as a `DateTime`.
/// Corresponds to `CONVERT_TZ(dateTime, '+00:00', toTimeZone)`.
/// </summary>
/// <param name="_">The DbFunctions instance.</param>
/// <param name="dateTimeOffset">The `DateTimeOffset` value to convert.</param>
/// <param name="toTimeZone">The time zone to convert to.</param>
/// <returns>The converted `DateTime?` value.</returns>
public static DateTime? ConvertTimeZone(
[CanBeNull] this DbFunctions _,
DateTimeOffset dateTimeOffset,
string toTimeZone)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(ConvertTimeZone)));

/// <summary>
/// Converts the `DateOnly` value <paramref name="dateOnly"/> from `@@session.time_zone` to the time zone given by <paramref name="toTimeZone"/> and returns the resulting value.
/// Corresponds to `CONVERT_TZ(dateTime, @@session.time_zone, toTimeZone)`.
/// </summary>
/// <param name="_">The DbFunctions instance.</param>
/// <param name="dateOnly">The `DateOnly` value to convert.</param>
/// <param name="toTimeZone">The time zone to convert to.</param>
/// <returns>The converted value.</returns>
public static DateOnly? ConvertTimeZone(
[CanBeNull] this DbFunctions _,
DateOnly dateOnly,
string toTimeZone)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(ConvertTimeZone)));

/// <summary>
/// Converts the `DateTime?` value <paramref name="dateTime"/> from `@@session.time_zone` to the time zone given by <paramref name="toTimeZone"/> and returns the resulting value.
/// Corresponds to `CONVERT_TZ(dateTime, @@session.time_zone, toTimeZone)`.
/// </summary>
/// <param name="_">The DbFunctions instance.</param>
/// <param name="dateTime">The `DateTime?` value to convert.</param>
/// <param name="toTimeZone">The time zone to convert to.</param>
/// <returns>The converted value.</returns>
public static DateTime? ConvertTimeZone(
[CanBeNull] this DbFunctions _,
DateTime? dateTime,
string toTimeZone)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(ConvertTimeZone)));

/// <summary>
/// Converts the `DateTimeOffset?` value <paramref name="dateTimeOffset"/> from `+00:00`/UTC to the time zone given by <paramref name="toTimeZone"/> and returns the resulting value as a `DateTime`.
/// Corresponds to `CONVERT_TZ(dateTime, '+00:00', toTimeZone)`.
/// </summary>
/// <param name="_">The DbFunctions instance.</param>
/// <param name="dateTimeOffset">The `DateTimeOffset?` value to convert.</param>
/// <param name="toTimeZone">The time zone to convert to.</param>
/// <returns>The converted `DateTime?` value.</returns>
public static DateTime? ConvertTimeZone(
[CanBeNull] this DbFunctions _,
DateTimeOffset? dateTimeOffset,
string toTimeZone)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(ConvertTimeZone)));

/// <summary>
/// Converts the `DateOnly?` value <paramref name="dateOnly"/> from `@@session.time_zone` to the time zone given by <paramref name="toTimeZone"/> and returns the resulting value.
/// Corresponds to `CONVERT_TZ(dateTime, @@session.time_zone, toTimeZone)`.
/// </summary>
/// <param name="_">The DbFunctions instance.</param>
/// <param name="dateOnly">The `DateOnly?` value to convert.</param>
/// <param name="toTimeZone">The time zone to convert to.</param>
/// <returns>The converted value.</returns>
public static DateOnly? ConvertTimeZone(
[CanBeNull] this DbFunctions _,
DateOnly? dateOnly,
string toTimeZone)
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(ConvertTimeZone)));

#endregion ConvertTimeZone

/// <summary>
/// Counts the number of year boundaries crossed between the startDate and endDate.
/// Corresponds to TIMESTAMPDIFF(YEAR,startDate,endDate).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.Storage;
using Pomelo.EntityFrameworkCore.MySql.Query.Internal;
using Pomelo.EntityFrameworkCore.MySql.Utilities;

namespace Pomelo.EntityFrameworkCore.MySql.Query.ExpressionTranslators.Internal
{
Expand All @@ -22,6 +23,40 @@ public class MySqlDbFunctionsExtensionsMethodTranslator : IMethodCallTranslator
{
private readonly MySqlSqlExpressionFactory _sqlExpressionFactory;

private static readonly HashSet<MethodInfo> _convertTimeZoneMethodInfos =
[
typeof(MySqlDbFunctionsExtensions).GetRuntimeMethod(
nameof(MySqlDbFunctionsExtensions.ConvertTimeZone),
new[] { typeof(DbFunctions), typeof(DateTime), typeof(string), typeof(string) }),
typeof(MySqlDbFunctionsExtensions).GetRuntimeMethod(
nameof(MySqlDbFunctionsExtensions.ConvertTimeZone),
new[] { typeof(DbFunctions), typeof(DateOnly), typeof(string), typeof(string) }),
typeof(MySqlDbFunctionsExtensions).GetRuntimeMethod(
nameof(MySqlDbFunctionsExtensions.ConvertTimeZone),
new[] { typeof(DbFunctions), typeof(DateTime?), typeof(string), typeof(string) }),
typeof(MySqlDbFunctionsExtensions).GetRuntimeMethod(
nameof(MySqlDbFunctionsExtensions.ConvertTimeZone),
new[] { typeof(DbFunctions), typeof(DateOnly?), typeof(string), typeof(string) }),
typeof(MySqlDbFunctionsExtensions).GetRuntimeMethod(
nameof(MySqlDbFunctionsExtensions.ConvertTimeZone),
new[] { typeof(DbFunctions), typeof(DateTime), typeof(string) }),
typeof(MySqlDbFunctionsExtensions).GetRuntimeMethod(
nameof(MySqlDbFunctionsExtensions.ConvertTimeZone),
new[] { typeof(DbFunctions), typeof(DateTimeOffset), typeof(string) }),
typeof(MySqlDbFunctionsExtensions).GetRuntimeMethod(
nameof(MySqlDbFunctionsExtensions.ConvertTimeZone),
new[] { typeof(DbFunctions), typeof(DateOnly), typeof(string) }),
typeof(MySqlDbFunctionsExtensions).GetRuntimeMethod(
nameof(MySqlDbFunctionsExtensions.ConvertTimeZone),
new[] { typeof(DbFunctions), typeof(DateTime?), typeof(string) }),
typeof(MySqlDbFunctionsExtensions).GetRuntimeMethod(
nameof(MySqlDbFunctionsExtensions.ConvertTimeZone),
new[] { typeof(DbFunctions), typeof(DateTimeOffset?), typeof(string) }),
typeof(MySqlDbFunctionsExtensions).GetRuntimeMethod(
nameof(MySqlDbFunctionsExtensions.ConvertTimeZone),
new[] { typeof(DbFunctions), typeof(DateOnly?), typeof(string) }),
];

private static readonly Type[] _supportedLikeTypes = {
typeof(int),
typeof(long),
Expand Down Expand Up @@ -148,6 +183,29 @@ public virtual SqlExpression Translate(
IReadOnlyList<SqlExpression> arguments,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
if (_convertTimeZoneMethodInfos.TryGetValue(method, out _))
{
// Will not just return `NULL` if any of its parameters is `NULL`, but also if `fromTimeZone` or `toTimeZone` is incorrect.
// Will do no conversion at all if `dateTime` is outside the supported range.
return _sqlExpressionFactory.NullableFunction(
"CONVERT_TZ",
arguments.Count == 3
?
[
arguments[1],
// The implicit fromTimeZone is UTC for DateTimeOffset values and the current session time zone otherwise.
method.GetParameters()[1].ParameterType.UnwrapNullableType() == typeof(DateTimeOffset)
? _sqlExpressionFactory.Constant("+00:00")
: _sqlExpressionFactory.Fragment("@@session.time_zone"),
arguments[2]
]
: new[] { arguments[1], arguments[2], arguments[3] },
method.ReturnType.UnwrapNullableType(),
null,
false,
Statics.GetTrueValues(arguments.Count));
}

if (_likeMethodInfos.Any(m => Equals(method, m)))
{
var match = _sqlExpressionFactory.ApplyDefaultTypeMapping(arguments[1]);
Expand Down
37 changes: 28 additions & 9 deletions src/EFCore.MySql/Storage/Internal/MySqlRelationalConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ protected static DbDataSource GetEffectiveDataSource(IMySqlOptions mySqlSingleto
=> mySqlSingletonOptions.DataSource ??
contextOptions.FindExtension<CoreOptionsExtension>()?.ApplicationServiceProvider?.GetService<MySqlDataSource>();

// TODO: Remove, because we don't use it anywhere.
private bool IsMasterConnection { get; set; }

protected override DbConnection CreateDbConnection()
Expand Down Expand Up @@ -255,7 +254,9 @@ public override bool Open(bool errorsExpected = false)

if (result)
{
if (_mySqlOptionsExtension.UpdateSqlModeOnOpen && _mySqlOptionsExtension.NoBackslashEscapes)
if (_mySqlOptionsExtension.UpdateSqlModeOnOpen &&
_mySqlOptionsExtension.NoBackslashEscapes &&
!IsMasterConnection)
{
AddSqlMode(NoBackslashEscapes);
}
Expand All @@ -271,9 +272,11 @@ public override async Task<bool> OpenAsync(CancellationToken cancellationToken,

if (result)
{
if (_mySqlOptionsExtension.UpdateSqlModeOnOpen && _mySqlOptionsExtension.NoBackslashEscapes)
if (_mySqlOptionsExtension.UpdateSqlModeOnOpen &&
_mySqlOptionsExtension.NoBackslashEscapes &&
!IsMasterConnection)
{
await AddSqlModeAsync(NoBackslashEscapes)
await AddSqlModeAsync(NoBackslashEscapes, cancellationToken)
.ConfigureAwait(false);
}
}
Expand All @@ -282,16 +285,32 @@ await AddSqlModeAsync(NoBackslashEscapes)
}

public virtual void AddSqlMode(string mode)
=> Dependencies.CurrentContext.Context?.Database.ExecuteSqlInterpolated($@"SET SESSION sql_mode = CONCAT(@@sql_mode, ',', {mode});");
=> ExecuteNonQuery($@"SET SESSION sql_mode = CONCAT(@@sql_mode, ',', '{mode}');");

public virtual Task AddSqlModeAsync(string mode, CancellationToken cancellationToken = default)
=> Dependencies.CurrentContext.Context?.Database.ExecuteSqlInterpolatedAsync($@"SET SESSION sql_mode = CONCAT(@@sql_mode, ',', {mode});", cancellationToken);
=> ExecuteNonQueryAsync($@"SET SESSION sql_mode = CONCAT(@@sql_mode, ',', '{mode}');", cancellationToken);

public virtual void RemoveSqlMode(string mode)
=> Dependencies.CurrentContext.Context?.Database.ExecuteSqlInterpolated($@"SET SESSION sql_mode = REPLACE(@@sql_mode, {mode}, '');");
=> ExecuteNonQuery($@"SET SESSION sql_mode = REPLACE(@@sql_mode, '{mode}', '');");

public virtual void RemoveSqlModeAsync(string mode, CancellationToken cancellationToken = default)
=> Dependencies.CurrentContext.Context?.Database.ExecuteSqlInterpolatedAsync($@"SET SESSION sql_mode = REPLACE(@@sql_mode, {mode}, '');", cancellationToken);
public virtual Task RemoveSqlModeAsync(string mode, CancellationToken cancellationToken = default)
=> ExecuteNonQueryAsync($@"SET SESSION sql_mode = REPLACE(@@sql_mode, '{mode}', '');", cancellationToken);

protected virtual void ExecuteNonQuery(string sql)
{
using var command = DbConnection.CreateCommand();
command.CommandText = sql;
command.ExecuteNonQuery();
}

protected virtual async Task ExecuteNonQueryAsync(string sql, CancellationToken cancellationToken = default)
{
var command = DbConnection.CreateCommand();
await using (command.ConfigureAwait(false))
{
command.CommandText = sql;
await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
}
}
}
}
27 changes: 27 additions & 0 deletions test/EFCore.MySql.FunctionalTests/ConnectionMySqlTest.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.Extensions.DependencyInjection;
using MySqlConnector;
using Pomelo.EntityFrameworkCore.MySql.FunctionalTests.TestUtilities;
Expand Down Expand Up @@ -113,6 +114,32 @@ public void Can_create_admin_connection_with_connection()
masterConnection.Open();
}

[Fact]
public void Can_create_database_with_disablebackslashescaping()
{
var optionsBuilder = new DbContextOptionsBuilder<GeneralOptionsContext>();
optionsBuilder.UseMySql(MySqlTestStore.CreateConnectionString("ConnectionTest_" + Guid.NewGuid()), AppConfig.ServerVersion, b => b.ApplyConfiguration().DisableBackslashEscaping());
using var context = new GeneralOptionsContext(optionsBuilder.Options);

var relationalDatabaseCreator = context.GetService<IRelationalDatabaseCreator>();

try
{
relationalDatabaseCreator.EnsureCreated();
}
finally
{
try
{
relationalDatabaseCreator.EnsureDeleted();
}
catch
{
// ignored
}
}
}

private readonly IServiceProvider _serviceProvider = new ServiceCollection()
.AddEntityFrameworkMySql()
.BuildServiceProvider();
Expand Down
Loading

0 comments on commit 80fb595

Please sign in to comment.