Skip to content

Commit

Permalink
test(controller): add unit tests (#1222)
Browse files Browse the repository at this point in the history
add unit tests
  • Loading branch information
dreamlandliu authored Sep 16, 2022
1 parent 10fbe73 commit 68216c6
Show file tree
Hide file tree
Showing 10 changed files with 640 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import ai.starwhale.mlops.exception.SwValidationException.ValidSubject;
import ai.starwhale.mlops.exception.api.StarwhaleApiException;
import com.github.pagehelper.PageInfo;
import javax.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
Expand All @@ -47,24 +46,24 @@
@RequestMapping("${sw.controller.apiPrefix}")
public class JobController implements JobApi {

@Resource
private JobService jobService;

@Resource
private TaskService taskService;

@Resource
private IdConvertor idConvertor;

@Resource
private DagQuerier dagQuerier;


private final InvokerManager<String, String> jobActions = InvokerManager.<String, String>create()
.addInvoker("cancel", (String jobUrl) -> jobService.cancelJob(jobUrl))
.addInvoker("pause", (String jobUrl) -> jobService.pauseJob(jobUrl))
.addInvoker("resume", (String jobUrl) -> jobService.resumeJob(jobUrl))
.unmodifiable();
private final JobService jobService;
private final TaskService taskService;
private final IdConvertor idConvertor;
private final DagQuerier dagQuerier;
private final InvokerManager<String, String> jobActions;

public JobController(JobService jobService, TaskService taskService, IdConvertor idConvertor,
DagQuerier dagQuerier) {
this.jobService = jobService;
this.taskService = taskService;
this.idConvertor = idConvertor;
this.dagQuerier = dagQuerier;
this.jobActions = InvokerManager.<String, String>create()
.addInvoker("cancel", jobService::cancelJob)
.addInvoker("pause", jobService::pauseJob)
.addInvoker("resume", jobService::resumeJob)
.unmodifiable();
}

@Override
public ResponseEntity<ResponseMessage<PageInfo<JobVo>>> listJobs(String projectUrl, String swmpId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import ai.starwhale.mlops.domain.system.SystemService;
import com.github.pagehelper.PageInfo;
import java.util.List;
import javax.annotation.Resource;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
Expand All @@ -36,8 +35,11 @@
@RequestMapping("${sw.controller.apiPrefix}")
public class SystemController implements SystemApi {

@Resource
private SystemService systemService;
private final SystemService systemService;

public SystemController(SystemService systemService) {
this.systemService = systemService;
}

@Override
public ResponseEntity<ResponseMessage<PageInfo<AgentVo>>> listAgent(String ip, Integer pageNum,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@
import ai.starwhale.mlops.domain.job.po.BaseImageEntity;
import ai.starwhale.mlops.exception.ConvertException;
import java.util.Objects;
import javax.annotation.Resource;
import org.springframework.stereotype.Component;

@Component
public class BaseImageConvertor implements Convertor<BaseImageEntity, BaseImageVo> {

@Resource
private IdConvertor idConvertor;
private final IdConvertor idConvertor;

public BaseImageConvertor(IdConvertor idConvertor) {
this.idConvertor = idConvertor;
}

@Override
public BaseImageVo convert(BaseImageEntity baseImageEntity) throws ConvertException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import ai.starwhale.mlops.exception.SwValidationException.ValidSubject;
import ai.starwhale.mlops.exception.api.StarwhaleApiException;
import cn.hutool.core.util.StrUtil;
import javax.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
Expand All @@ -33,12 +32,13 @@
@Service
public class JobManager {

@Resource
private JobMapper jobMapper;

@Resource
private IdConvertor idConvertor;
private final JobMapper jobMapper;
private final IdConvertor idConvertor;

public JobManager(JobMapper jobMapper, IdConvertor idConvertor) {
this.jobMapper = jobMapper;
this.idConvertor = idConvertor;
}

public Long getJobId(String jobUrl) {
Job job = fromUrl(jobUrl);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

import ai.starwhale.mlops.domain.swds.po.SwDatasetVersionEntity;
import java.util.List;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;

@Mapper
public interface JobSwdsVersionMapper {

List<SwDatasetVersionEntity> listSwdsVersionsByJobId(@Param("jobId") Long jobId);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
/*
* Copyright 2022 Starwhale, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.starwhale.mlops.api;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.hasProperty;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.isA;
import static org.hamcrest.Matchers.notNullValue;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyFloat;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

import ai.starwhale.mlops.api.protocol.job.JobModifyRequest;
import ai.starwhale.mlops.api.protocol.job.JobRequest;
import ai.starwhale.mlops.api.protocol.job.JobVo;
import ai.starwhale.mlops.api.protocol.task.TaskVo;
import ai.starwhale.mlops.common.IdConvertor;
import ai.starwhale.mlops.common.PageParams;
import ai.starwhale.mlops.domain.dag.DagQuerier;
import ai.starwhale.mlops.domain.job.JobService;
import ai.starwhale.mlops.domain.task.TaskService;
import ai.starwhale.mlops.exception.api.StarwhaleApiException;
import com.github.pagehelper.Page;
import java.util.List;
import java.util.Objects;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.http.HttpStatus;

public class JobControllerTest {

private JobController controller;

private JobService jobService;

private TaskService taskService;

private DagQuerier dagQuerier;

@BeforeEach
public void setUp() {
jobService = mock(JobService.class);
taskService = mock(TaskService.class);
dagQuerier = mock(DagQuerier.class);
controller = new JobController(
jobService,
taskService,
new IdConvertor(),
dagQuerier
);
}

@Test
public void testListJobs() {
given(jobService.listJobs(same("p1"), same(1L), any(PageParams.class)))
.willAnswer(invocation -> {
PageParams pageParams = invocation.getArgument(2);
try (Page<JobVo> page = new Page<>(pageParams.getPageNum(), pageParams.getPageSize())) {
page.add(JobVo.builder().build());
return page.toPageInfo();
}
});
var resp = controller.listJobs("p1", "1", 3, 10);
assertThat(resp.getStatusCode(), is(HttpStatus.OK));
assertThat(Objects.requireNonNull(resp.getBody()).getData(), allOf(
notNullValue(),
is(hasProperty("pageNum", is(3))),
is(hasProperty("pageSize", is(10))),
is(hasProperty("list", isA(List.class)))
));
}

@Test
public void testFindJob() {
given(jobService.findJob(same("p1"), same("j1")))
.willReturn(JobVo.builder().id("j1").build());
var resp = controller.findJob("p1", "j1");
assertThat(resp.getStatusCode(), is(HttpStatus.OK));
assertThat(Objects.requireNonNull(resp.getBody()).getData(), allOf(
notNullValue(),
is(hasProperty("id", is("j1")))
));
}

@Test
public void testListTasks() {
given(taskService.listTasks(same("j1"), any(PageParams.class)))
.willAnswer(invocation -> {
PageParams pageParams = invocation.getArgument(1);
try (Page<TaskVo> page = new Page<>(pageParams.getPageNum(), pageParams.getPageSize())) {
page.add(TaskVo.builder().build());
return page.toPageInfo();
}
});
var resp = controller.listTasks("p1", "j1", 3, 5);
assertThat(resp.getStatusCode(), is(HttpStatus.OK));
assertThat(Objects.requireNonNull(resp.getBody()).getData(), allOf(
notNullValue(),
is(hasProperty("pageNum", is(3))),
is(hasProperty("pageSize", is(5))),
is(hasProperty("list", isA(List.class)))
));
}

@Test
public void testCreatJob() {
given(jobService.createJob(anyString(), anyString(),
anyString(), anyString(), anyString(),
anyFloat(), anyString(), anyString()))
.willReturn(1L);
JobRequest jobRequest = new JobRequest();
jobRequest.setComment("");
jobRequest.setDevice("");
jobRequest.setModelVersionUrl("");
jobRequest.setDatasetVersionUrls("");
jobRequest.setRuntimeVersionUrl("");
jobRequest.setResourcePool("");
jobRequest.setDeviceAmount(0f);
var resp = controller.createJob("p1", jobRequest);
assertThat(resp.getStatusCode(), is(HttpStatus.OK));
assertThat(Objects.requireNonNull(resp.getBody()).getData(), allOf(
notNullValue(),
is("1")
));

}

private String invoked = "";

@Test
public void testAction() {
doAnswer(invocation -> invoked = "cancel_" + invocation.getArgument(0))
.when(jobService).cancelJob(anyString());
doAnswer(invocation -> invoked = "pause_" + invocation.getArgument(0))
.when(jobService).pauseJob(anyString());
doAnswer(invocation -> invoked = "resume_" + invocation.getArgument(0))
.when(jobService).resumeJob(anyString());

controller.action("", "job1", "cancel");
assertThat(invoked, is("cancel_job1"));

controller.action("", "job2", "pause");
assertThat(invoked, is("pause_job2"));

controller.action("", "job3", "resume");
assertThat(invoked, is("resume_job3"));

assertThrows(StarwhaleApiException.class,
() -> controller.action("", "job1", null));

assertThrows(StarwhaleApiException.class,
() -> controller.action("", "job1", "a"));
}

@Test
public void testGetJobResult() {
given(jobService.getJobResult(anyString(), anyString()))
.willAnswer(invocation -> "result_" + invocation.getArgument(0)
+ "_" + invocation.getArgument(1));
var resp = controller.getJobResult("project1", "job1");
assertThat(resp.getStatusCode(), is(HttpStatus.OK));
assertThat(Objects.requireNonNull(resp.getBody()).getData(), is("result_project1_job1"));
}

@Test
public void testModifyJobComment() {
given(jobService.updateJobComment(same("p1"), same("j1"), same("comment1")))
.willReturn(true);
JobModifyRequest request = new JobModifyRequest();
request.setComment("comment1");
var resp = controller.modifyJobComment("p1", "j1", request);
assertThat(resp.getStatusCode(), is(HttpStatus.OK));

assertThrows(StarwhaleApiException.class,
() -> controller.modifyJobComment("p1", "j2", request));
}

@Test
public void testRemoveJob() {
given(jobService.removeJob(same("p1"), same("j1")))
.willReturn(true);

var resp = controller.removeJob("p1", "j1");
assertThat(resp.getStatusCode(), is(HttpStatus.OK));

assertThrows(StarwhaleApiException.class,
() -> controller.removeJob("p1", "j2"));
}

@Test
public void testRecoverJob() {
given(jobService.recoverJob(same("p1"), same("j1")))
.willReturn(true);

var resp = controller.recoverJob("p1", "j1");
assertThat(resp.getStatusCode(), is(HttpStatus.OK));

assertThrows(StarwhaleApiException.class,
() -> controller.recoverJob("p1", "j2"));
}
}
Loading

0 comments on commit 68216c6

Please sign in to comment.