diff --git a/FAnsiSql/Discovery/DiscoveredServerHelper.cs b/FAnsiSql/Discovery/DiscoveredServerHelper.cs index e2cb03f9..c4256ab0 100644 --- a/FAnsiSql/Discovery/DiscoveredServerHelper.cs +++ b/FAnsiSql/Discovery/DiscoveredServerHelper.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Data.Common; using System.Linq; +using System.Runtime.CompilerServices; using System.Text.RegularExpressions; using System.Threading; using FAnsi.Connections; @@ -108,16 +109,16 @@ public virtual DbConnectionStringBuilder ChangeDatabase(DbConnectionStringBuilde public abstract IEnumerable ListDatabases(DbConnectionStringBuilder builder); public abstract IEnumerable ListDatabases(DbConnection con); - public IAsyncEnumerable ListDatabasesAsync(DbConnectionStringBuilder builder, CancellationToken token) + public async IAsyncEnumerable ListDatabasesAsync(DbConnectionStringBuilder builder, [EnumeratorCancellation] CancellationToken token) { //list the database on the server - using var con = GetConnection(builder); + await using var con = GetConnection(builder); //this will work or timeout - var openTask = con.OpenAsync(token); - openTask.Wait(token); + await con.OpenAsync(token); - return ListDatabases(con).ToAsyncEnumerable(); + foreach (var db in ListDatabases(con)) + yield return db; } public abstract DbConnectionStringBuilder EnableAsync(DbConnectionStringBuilder builder); diff --git a/Tests/FAnsiTests/Database/DatabaseLevelTests.cs b/Tests/FAnsiTests/Database/DatabaseLevelTests.cs index 7ae898c6..5c18f271 100644 --- a/Tests/FAnsiTests/Database/DatabaseLevelTests.cs +++ b/Tests/FAnsiTests/Database/DatabaseLevelTests.cs @@ -1,14 +1,21 @@ -using FAnsi; +using System; +using System.Collections; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using FAnsi; using FAnsi.Discovery; using FAnsi.Implementation; using NUnit.Framework; +using NUnit.Framework.Constraints; +using Oracle.ManagedDataAccess.Client; using TypeGuesser; namespace FAnsiTests.Database; internal sealed class DatabaseLevelTests : DatabaseTests { - [TestCaseSource(typeof(All),nameof(All.DatabaseTypes))] + [TestCaseSource(typeof(All), nameof(All.DatabaseTypes))] public void Database_Exists(DatabaseType type) { var server = GetTestDatabase(type); @@ -16,36 +23,57 @@ public void Database_Exists(DatabaseType type) } - [TestCase(DatabaseType.MySql,false)] - [TestCase(DatabaseType.MicrosoftSQLServer,false)] - [TestCase(DatabaseType.Oracle,true)] - [TestCase(DatabaseType.PostgreSql,false)] + [TestCase(DatabaseType.MySql, false)] + [TestCase(DatabaseType.MicrosoftSQLServer, false)] + [TestCase(DatabaseType.Oracle, true)] + [TestCase(DatabaseType.PostgreSql, false)] public void Test_ExpectDatabase(DatabaseType type, bool upperCase) { var helper = ImplementationManager.GetImplementation(type).GetServerHelper(); - var server = new DiscoveredServer(helper.GetConnectionStringBuilder("loco","db","frank","kangaro")); + var server = new DiscoveredServer(helper.GetConnectionStringBuilder("loco", "db", "frank", "kangaro")); var db = server.ExpectDatabase("omg"); - Assert.That(db.GetRuntimeName(), Is.EqualTo(upperCase ?"OMG":"omg")); + Assert.That(db.GetRuntimeName(), Is.EqualTo(upperCase ? "OMG" : "omg")); } - [TestCaseSource(typeof(All),nameof(All.DatabaseTypes))] + [TestCaseSource(typeof(All), nameof(All.DatabaseTypes))] public void Test_CreateSchema(DatabaseType type) { var db = GetTestDatabase(type); - Assert.DoesNotThrow(()=>db.CreateSchema("Fr ank")); - Assert.DoesNotThrow(()=>db.CreateSchema("Fr ank")); + Assert.DoesNotThrow(() => db.CreateSchema("Fr ank")); + Assert.DoesNotThrow(() => db.CreateSchema("Fr ank")); db.Server.GetQuerySyntaxHelper().EnsureWrapped("Fr ank"); if (type is not (DatabaseType.MicrosoftSQLServer or DatabaseType.PostgreSql)) return; var tbl = db.CreateTable("Heyyy", - [new DatabaseColumnRequest("fff", new DatabaseTypeRequest(typeof(string), 10))],"Fr ank"); + [new DatabaseColumnRequest("fff", new DatabaseTypeRequest(typeof(string), 10))], "Fr ank"); Assert.That(tbl.Exists()); - if(type == DatabaseType.MicrosoftSQLServer) + if (type == DatabaseType.MicrosoftSQLServer) Assert.That(tbl.Schema, Is.EqualTo("Fr ank")); } + + [TestCaseSource(typeof(All), nameof(All.DatabaseTypes))] + public void TestListDatabasesAsync(DatabaseType type) + { + var db = GetTestDatabase(type, false); + + Constraint exceptionType = type switch + { + DatabaseType.MySql => Throws.TypeOf(), + DatabaseType.MicrosoftSQLServer => Throws.TypeOf(), + DatabaseType.PostgreSql => Throws.Nothing, + DatabaseType.Oracle => Throws.TypeOf(), + _ => throw new ArgumentOutOfRangeException(nameof(type), type, null) + }; + + Assert.That( + () => db.Server.Helper.ListDatabasesAsync(db.Server.Builder, new CancellationToken(true)) + .ToBlockingEnumerable().ToList(), exceptionType); + var databases = db.Server.Helper.ListDatabasesAsync(db.Server.Builder, CancellationToken.None).ToBlockingEnumerable().ToList(); + Assert.That(databases, Has.Member(db.GetRuntimeName()).Using((IEqualityComparer)StringComparer.OrdinalIgnoreCase)); + } } \ No newline at end of file