Skip to content

Commit

Permalink
feat(controller): fill SFT list fields (#2980)
Browse files Browse the repository at this point in the history
  • Loading branch information
anda-ren authored Nov 15, 2023
1 parent 4a93aca commit ee53554
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ public PageInfo<DatasetVersionVo> listDatasetVersionHistory(DatasetVersionQuery
}

public List<DatasetVo> findDatasetsByVersionIds(List<Long> versionIds) {
if (versionIds.isEmpty()) {
if (CollectionUtils.isEmpty(versionIds)) {
return List.of();
}
List<DatasetVersionEntity> versions = datasetVersionMapper.findByIds(Joiner.on(",").join(versionIds));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
package ai.starwhale.mlops.domain.ft;

import ai.starwhale.mlops.api.protocol.ft.FineTuneCreateRequest;
import ai.starwhale.mlops.api.protocol.model.ModelVo;
import ai.starwhale.mlops.common.Constants;
import ai.starwhale.mlops.common.IdConverter;
import ai.starwhale.mlops.configuration.FeaturesProperties;
import ai.starwhale.mlops.domain.bundle.base.BundleEntity;
import ai.starwhale.mlops.domain.dataset.DatasetDao;
import ai.starwhale.mlops.domain.dataset.DatasetService;
import ai.starwhale.mlops.domain.dataset.bo.DatasetVersion;
import ai.starwhale.mlops.domain.ft.mapper.FineTuneMapper;
import ai.starwhale.mlops.domain.ft.mapper.FineTuneSpaceMapper;
Expand All @@ -31,13 +33,15 @@
import ai.starwhale.mlops.domain.job.JobCreator;
import ai.starwhale.mlops.domain.job.JobType;
import ai.starwhale.mlops.domain.job.bo.Job;
import ai.starwhale.mlops.domain.job.converter.JobConverter;
import ai.starwhale.mlops.domain.job.converter.UserJobConverter;
import ai.starwhale.mlops.domain.job.mapper.JobMapper;
import ai.starwhale.mlops.domain.job.po.JobEntity;
import ai.starwhale.mlops.domain.job.spec.Env;
import ai.starwhale.mlops.domain.job.spec.JobSpecParser;
import ai.starwhale.mlops.domain.job.spec.StepSpec;
import ai.starwhale.mlops.domain.model.ModelDao;
import ai.starwhale.mlops.domain.model.ModelService;
import ai.starwhale.mlops.domain.model.bo.ModelVersion;
import ai.starwhale.mlops.domain.model.po.ModelEntity;
import ai.starwhale.mlops.domain.model.po.ModelVersionEntity;
Expand Down Expand Up @@ -88,6 +92,12 @@ public class FineTuneAppService {

final UserJobConverter userJobConverter;

final JobConverter jobConverter;

final ModelService modelService;

final DatasetService datasetService;

public FineTuneAppService(
FeaturesProperties featuresProperties,
JobCreator jobCreator,
Expand All @@ -99,7 +109,10 @@ public FineTuneAppService(
@Value("${sw.instance-uri}") String instanceUri,
DatasetDao datasetDao,
FineTuneSpaceMapper fineTuneSpaceMapper,
UserJobConverter userJobConverter
UserJobConverter userJobConverter,
JobConverter jobConverter,
ModelService modelService,
DatasetService datasetService
) {
this.featuresProperties = featuresProperties;
this.jobCreator = jobCreator;
Expand All @@ -112,6 +125,9 @@ public FineTuneAppService(
this.instanceUri = instanceUri;
this.fineTuneSpaceMapper = fineTuneSpaceMapper;
this.userJobConverter = userJobConverter;
this.jobConverter = jobConverter;
this.modelService = modelService;
this.datasetService = datasetService;
}


Expand Down Expand Up @@ -200,17 +216,21 @@ public PageInfo<FineTuneVo> list(Long spaceId, Integer pageNum, Integer pageSize
fineTuneEntity.getEvalDatasets();
fineTuneEntity.getTrainDatasets();
fineTuneEntity.getBaseModelVersionId();
fineTuneEntity.getTargetModelVersionId();
ModelVo mv = null;
Long targetModelVersionId = fineTuneEntity.getTargetModelVersionId();
if (null != targetModelVersionId) {
List<ModelVo> modelVos = modelService.findModelByVersionId(List.of(targetModelVersionId));
if (!CollectionUtils.isEmpty(modelVos)) {
mv = modelVos.get(0);
}
}

return FineTuneVo.builder()
.id(fineTuneEntity.getId())
.jobId(jobId)
.status(job.getJobStatus())
.startTime(job.getCreatedTime().getTime())
.endTime(null != job.getFinishedTime() ? job.getFinishedTime().getTime() : null)
.evalDatasets(List.of())//TODO
.trainDatasets(List.of())//TODO
.baseModel(null)//TODO
.targetModel(null)//TODO
.job(jobConverter.convert(job))
.evalDatasets(datasetService.findDatasetsByVersionIds(fineTuneEntity.getTrainDatasets()))
.trainDatasets(datasetService.findDatasetsByVersionIds(fineTuneEntity.getEvalDatasets()))
.targetModel(mv)
.build();
}).collect(Collectors.toList()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

package ai.starwhale.mlops.domain.ft.vo;

import ai.starwhale.mlops.domain.job.status.JobStatus;
import ai.starwhale.mlops.api.protocol.dataset.DatasetVo;
import ai.starwhale.mlops.api.protocol.job.JobVo;
import ai.starwhale.mlops.api.protocol.model.ModelVo;
import java.util.List;
import javax.validation.constraints.NotNull;
import lombok.AllArgsConstructor;
Expand All @@ -33,29 +35,11 @@ public class FineTuneVo {
@NotNull
Long id;
@NotNull
Long jobId;
JobVo job;
@NotNull
JobStatus status;
List<DatasetVo> trainDatasets;
List<DatasetVo> evalDatasets;
@NotNull
Long startTime;
Long endTime;
List<DsInfo> trainDatasets;
List<DsInfo> evalDatasets;
@NotNull
ModelInfo baseModel;
ModelInfo targetModel;

public static class DsInfo {
String name;
String version;
Long id;
}

public static class ModelInfo {
String name;
String version;
Long id;
}

ModelVo targetModel;

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import ai.starwhale.mlops.common.IdConverter;
import ai.starwhale.mlops.configuration.FeaturesProperties;
import ai.starwhale.mlops.domain.dataset.DatasetDao;
import ai.starwhale.mlops.domain.dataset.DatasetService;
import ai.starwhale.mlops.domain.dataset.bo.DatasetVersion;
import ai.starwhale.mlops.domain.ft.mapper.FineTuneMapper;
import ai.starwhale.mlops.domain.ft.mapper.FineTuneSpaceMapper;
Expand All @@ -35,12 +36,14 @@
import ai.starwhale.mlops.domain.job.JobCreator;
import ai.starwhale.mlops.domain.job.bo.Job;
import ai.starwhale.mlops.domain.job.bo.UserJobCreateRequest;
import ai.starwhale.mlops.domain.job.converter.JobConverter;
import ai.starwhale.mlops.domain.job.converter.UserJobConverter;
import ai.starwhale.mlops.domain.job.mapper.JobMapper;
import ai.starwhale.mlops.domain.job.po.JobEntity;
import ai.starwhale.mlops.domain.job.spec.JobSpecParser;
import ai.starwhale.mlops.domain.job.spec.StepSpec;
import ai.starwhale.mlops.domain.model.ModelDao;
import ai.starwhale.mlops.domain.model.ModelService;
import ai.starwhale.mlops.domain.model.po.ModelEntity;
import ai.starwhale.mlops.domain.model.po.ModelVersionEntity;
import ai.starwhale.mlops.domain.project.bo.Project;
Expand Down Expand Up @@ -83,8 +86,8 @@ public void setup() {
jobSpecParser = mock(JobSpecParser.class);
modelDao = mock(ModelDao.class);
datasetDao = mock(DatasetDao.class);
UserJobConverter jobConverter = mock(UserJobConverter.class);
when(jobConverter.convert(any(), any())).thenReturn(UserJobCreateRequest.builder().build());
UserJobConverter userJobConverter = mock(UserJobConverter.class);
when(userJobConverter.convert(any(), any())).thenReturn(UserJobCreateRequest.builder().build());
fineTuneSpaceMapper = mock(FineTuneSpaceMapper.class);
featuresProperties = mock(FeaturesProperties.class);
when(featuresProperties.isFineTuneEnabled()).thenReturn(true);
Expand All @@ -98,7 +101,11 @@ public void setup() {
modelDao,
"instanceuri",
datasetDao,
fineTuneSpaceMapper, jobConverter//todo
fineTuneSpaceMapper,
userJobConverter,
mock(JobConverter.class),
mock(ModelService.class),
mock(DatasetService.class)
);
}

Expand Down

0 comments on commit ee53554

Please sign in to comment.