@@ -14,6 +14,7 @@ import com.intellij.openapi.command.executeCommand
14
14
import com.intellij.openapi.editor.Document
15
15
import com.intellij.openapi.editor.Editor
16
16
import com.intellij.openapi.fileTypes.FileType
17
+ import com.intellij.openapi.module.Module
17
18
import com.intellij.openapi.project.DumbService
18
19
import com.intellij.openapi.project.Project
19
20
import com.intellij.openapi.util.Computable
@@ -26,6 +27,7 @@ import com.intellij.psi.PsiDocumentManager
26
27
import com.intellij.psi.PsiElement
27
28
import com.intellij.psi.PsiFile
28
29
import com.intellij.psi.PsiFileFactory
30
+ import com.intellij.psi.PsiManager
29
31
import com.intellij.psi.PsiMethod
30
32
import com.intellij.psi.codeStyle.CodeStyleManager
31
33
import com.intellij.psi.codeStyle.JavaCodeStyleManager
@@ -76,6 +78,7 @@ import org.utbot.intellij.plugin.ui.TestsReportNotifier
76
78
import org.utbot.intellij.plugin.ui.WarningTestsReportNotifier
77
79
import org.utbot.intellij.plugin.ui.utils.getOrCreateSarifReportsPath
78
80
import org.utbot.intellij.plugin.ui.utils.showErrorDialogLater
81
+ import org.utbot.intellij.plugin.ui.utils.suitableTestSourceRoots
79
82
import org.utbot.intellij.plugin.util.RunConfigurationHelper
80
83
import org.utbot.intellij.plugin.util.signature
81
84
import org.utbot.sarif.SarifReport
@@ -148,7 +151,11 @@ object CodeGenerationController {
148
151
149
152
run (EDT_LATER ) {
150
153
waitForCountDown(latch, timeout = 100 , timeUnit = TimeUnit .MILLISECONDS ) {
151
- val existingUtilClass = model.codegenLanguage.getUtilClassOrNull(baseTestDirectory)
154
+ val project = model.project
155
+ val language = model.codegenLanguage
156
+ val testModule = model.testModule
157
+
158
+ val existingUtilClass = language.getUtilClassOrNull(project, testModule)
152
159
153
160
val utilClassKind = utilClassListener.requiredUtilClassKind
154
161
? : return @waitForCountDown // no util class needed
@@ -345,6 +352,28 @@ object CodeGenerationController {
345
352
}
346
353
}
347
354
355
+ /* *
356
+ * @param project project whose classes we generate tests for.
357
+ * @param testModule module where the generated tests will be placed.
358
+ * @return an existing util class from one of the test source roots
359
+ * in the given [testModule] or `null` if no util class was found.
360
+ */
361
+ private fun CodegenLanguage.getUtilClassOrNull (project : Project , testModule : Module ): PsiFile ? {
362
+ val psiManager = PsiManager .getInstance(project)
363
+
364
+ // all test roots for the given test module
365
+ val testRoots = runReadAction {
366
+ testModule
367
+ .suitableTestSourceRoots(this )
368
+ .mapNotNull { psiManager.findDirectory(it) }
369
+ }
370
+
371
+ // return an util class from one of the test source roots or null if no util class was found
372
+ return testRoots
373
+ .mapNotNull { testRoot -> getUtilClassOrNull(testRoot) }
374
+ .firstOrNull()
375
+ }
376
+
348
377
/* *
349
378
* Create all package directories for UtUtils class.
350
379
* @return the innermost directory - utils from `org.utbot.runtime.utils`
0 commit comments