Skip to content

Commit

Permalink
chore(controller): keep release api align with UI (#2981)
Browse files Browse the repository at this point in the history
  • Loading branch information
anda-ren authored Nov 16, 2023
1 parent 4964b97 commit ae67c62
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,11 @@ public ResponseEntity<ResponseMessage<FineTuneVo>> fineTuneInfo(
@PreAuthorize("hasAnyRole('OWNER', 'MAINTAINER')")
public ResponseEntity<ResponseMessage<String>> releaseFt(
@RequestParam Long ftId,
@RequestParam(required = false) String modelName
@RequestParam(required = false) String nonExistingModelName,
@RequestParam(required = false) Long existingModelId

) {
fineTuneAppService.releaseFt(ftId, modelName, userService.currentUserDetail());
fineTuneAppService.releaseFt(ftId, existingModelId, nonExistingModelName, userService.currentUserDetail());
return ResponseEntity.ok(Code.success.asResponse(""));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,16 @@ public void evalFt(List<Long> evalDatasetIds, Long runtimeId, String handerSpec,

}

/**
* release fintuned model to either existingModelId or nonExistingModelName
*
* @param ftId release fintuned model to
* @param existingModelId either existingModelId
* @param nonExistingModelName or nonExistingModelName
* @param user by user
*/
@Transactional
public void releaseFt(Long ftId, String modelName, User user) {
public void releaseFt(Long ftId, Long existingModelId, String nonExistingModelName, User user) {
FineTuneEntity ft = fineTuneMapper.findById(ftId);
if (null == ft) {
throw new SwNotFoundException(ResourceType.FINE_TUNE, "fine tune not found");
Expand All @@ -288,33 +296,36 @@ public void releaseFt(Long ftId, String modelName, User user) {
);
}
Long modelId;
if (!StringUtils.hasText(modelName) || modelVersion.getModelName().equals(modelName)) {
modelId = modelVersion.getModelId();
} else {
//release to a new model
if (null != existingModelId) {
if (!existingModelId.equals(modelVersion.getModelId())) {
ModelEntity model = modelDao.getModel(existingModelId);
if (null == model) {
throw new SwNotFoundException(
ResourceType.BUNDLE,
"modelId not found: "
);
}
}
modelId = existingModelId;
} else if (StringUtils.hasText(nonExistingModelName)) {
FineTuneSpaceEntity ftSpace = fineTuneSpaceMapper.findById(ft.getSpaceId());
Long projectId = ftSpace.getProjectId();
BundleEntity modelEntity = this.modelDao.findByNameForUpdate(modelName, projectId);
if (null == modelEntity) {
//create model
ModelEntity model = ModelEntity.builder()
.ownerId(user.getId())
.projectId(projectId)
.modelName(modelName)
.build();
modelDao.add(model);
modelId = model.getId();
} else {
modelId = modelEntity.getId();
BundleEntity modelEntity = this.modelDao.findByNameForUpdate(nonExistingModelName, projectId);
if (null != modelEntity) {
throw new SwValidationException(ValidSubject.MODEL, "model name existed");
}
ModelEntity model = ModelEntity.builder()
.ownerId(user.getId())
.projectId(projectId)
.modelName(nonExistingModelName)
.build();
modelDao.add(model);
modelId = model.getId();
} else {
throw new SwValidationException(ValidSubject.MODEL, "nonExistingModelName xor existingModelId is required");
}
// update model version model id to new model and set draft to false
modelDao.releaseModelVersion(targetModelVersionId, modelId);

}

public void attachTargetModel(Long id, ModelVersionEntity modelVersionEntity) {
fineTuneMapper.updateTargetModel(id, modelVersionEntity.getId());
}

private void checkFeatureEnabled() throws StarwhaleApiException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class FineTuneAppServiceTest {

FeaturesProperties featuresProperties;

User creator = User.builder().build();
JobConverter jobConverter;

@BeforeEach
Expand Down Expand Up @@ -162,64 +163,103 @@ void evalFt() {

@Test
void releaseFt() {
User creator = User.builder().build();
when(fineTuneMapper.findById(1L)).thenReturn(null);
Assertions.assertThrows(SwNotFoundException.class, () -> {
fineTuneAppService.releaseFt(1L, "", null);
});

when(fineTuneMapper.findById(2L)).thenReturn(
FineTuneEntity.builder()
.targetModelVersionId(null)
.build()
);
Assertions.assertThrows(SwNotFoundException.class, () -> {
fineTuneAppService.releaseFt(2L, "", null);
});

when(fineTuneMapper.findById(3L)).thenReturn(
when(fineTuneMapper.findById(5L)).thenReturn(
FineTuneEntity.builder()
.targetModelVersionId(4L)
.targetModelVersionId(6L)
.spaceId(1L)
.build()
);
when(modelDao.getModelVersion("4")).thenReturn(ModelVersionEntity
when(modelDao.findByNameForUpdate(any(), anyLong())).thenReturn(ModelEntity.builder().id(124L).build());
when(modelDao.getModelVersion("6")).thenReturn(ModelVersionEntity
.builder()
.draft(false)
.modelId(10L)
.modelName("aac")
.draft(true)
.build());
when(fineTuneSpaceMapper.findById(anyLong())).thenReturn(FineTuneSpaceEntity.builder().projectId(1L).build());
Assertions.assertThrows(SwValidationException.class, () -> {
fineTuneAppService.releaseFt(3L, "", null);
fineTuneAppService.releaseFt(5L, null, "aabc", creator);
});
}

@Test
void releaseAndCreateNew() {
when(fineTuneMapper.findById(5L)).thenReturn(
FineTuneEntity.builder()
.targetModelVersionId(6L)
.spaceId(1L)
.build()
);
when(fineTuneSpaceMapper.findById(anyLong())).thenReturn(FineTuneSpaceEntity.builder().projectId(1L).build());
when(modelDao.getModelVersion("6")).thenReturn(ModelVersionEntity
.builder()
.modelId(10L)
.modelName("aac")
.draft(true)
.build());
fineTuneAppService.releaseFt(5L, null, creator);
verify(modelDao).releaseModelVersion(6L, 10L);


when(fineTuneSpaceMapper.findById(anyLong())).thenReturn(FineTuneSpaceEntity.builder().projectId(1L).build());
doAnswer(new Answer() {
public Object answer(InvocationOnMock invocation) {
Object[] args = invocation.getArguments();
((ModelEntity) args[0]).setId(123L);
return null; // void method, so return null
}
}).when(modelDao).add(any());
fineTuneAppService.releaseFt(5L, "aab", creator);
fineTuneAppService.releaseFt(5L, null, "aab", creator);
verify(modelDao).releaseModelVersion(6L, 123L);
}

when(modelDao.findByNameForUpdate(any(), anyLong())).thenReturn(ModelEntity.builder().id(124L).build());
fineTuneAppService.releaseFt(5L, "aabc", creator);
verify(modelDao).releaseModelVersion(6L, 124L);
@Test
void releaseSuccessWithBaseModel() {
when(fineTuneMapper.findById(5L)).thenReturn(
FineTuneEntity.builder()
.targetModelVersionId(6L)
.spaceId(1L)
.build()
);
when(modelDao.getModelVersion("6")).thenReturn(ModelVersionEntity
.builder()
.modelId(10L)
.modelName("aac")
.draft(true)
.build());
fineTuneAppService.releaseFt(5L, 10L, null, creator);
verify(modelDao).releaseModelVersion(6L, 10L);
}

@Test
void testTargetVersionReleased() {
when(fineTuneMapper.findById(3L)).thenReturn(
FineTuneEntity.builder()
.targetModelVersionId(4L)
.build()
);
when(modelDao.getModelVersion("4")).thenReturn(ModelVersionEntity
.builder()
.draft(false)
.build());
Assertions.assertThrows(SwValidationException.class, () -> {
fineTuneAppService.releaseFt(3L, 1L, "", null);
});
}

@Test
void testTargetVersionNull() {
when(fineTuneMapper.findById(2L)).thenReturn(
FineTuneEntity.builder()
.targetModelVersionId(null)
.build()
);
Assertions.assertThrows(SwNotFoundException.class, () -> {
fineTuneAppService.releaseFt(2L, 1L, "", null);
});
}

@Test
void testReleaseFtNotFound() {
when(fineTuneMapper.findById(1L)).thenReturn(null);
Assertions.assertThrows(SwNotFoundException.class, () -> {
fineTuneAppService.releaseFt(1L, 1L, "", null);
});
}

@Test
Expand Down

0 comments on commit ae67c62

Please sign in to comment.