Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

Get rid of IoC and remove unused training services #5567

Merged
merged 1 commit into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions ts/nni_manager/common/component.ts

This file was deleted.

55 changes: 55 additions & 0 deletions ts/nni_manager/common/ioc_shim.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

import assert from 'node:assert/strict';

type AbstractClass = {
name: string;
};

type Class = {
name: string;
new(): any;
};

class IocShimClass {
private singletons: Map<string, any> = new Map();
private snapshots: Map<string, any> = new Map();

public bind(keyClass: AbstractClass, valueClass: Class): void {
const key = keyClass.name;
assert.ok(!this.singletons.has(key));
this.singletons.set(key, new valueClass());
}

public bindInstance(keyClass: AbstractClass, value: any): void {
const key = keyClass.name;
assert.ok(!this.singletons.has(key));
this.singletons.set(key, value);
}

public get<T>(keyClass: AbstractClass): T {
const key = keyClass.name;
assert.ok(this.singletons.has(key));
return this.singletons.get(key);
}

public snapshot(keyClass: AbstractClass): void {
const key = keyClass.name;
const value = this.singletons.get(key);
this.snapshots.set(key, value);
}

public restore(keyClass: AbstractClass): void {
const key = keyClass.name;
const value = this.snapshots.get(key);
this.singletons.set(key, value);
}

// NOTE: for unit test only
public clear(): void {
this.singletons.clear();
}
}

export const IocShim: IocShimClass = new IocShimClass();
26 changes: 0 additions & 26 deletions ts/nni_manager/common/observableTimer.ts

This file was deleted.

18 changes: 9 additions & 9 deletions ts/nni_manager/common/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ import net from 'net';
import path from 'path';
import * as timersPromises from 'timers/promises';
import { Deferred } from 'ts-deferred';
import { Container } from 'typescript-ioc';

import { Database, DataStore } from './datastore';
import globals from './globals';
import { resetGlobals } from './globals/unittest'; // TODO: this file should not contain unittest helpers
import { IocShim } from './ioc_shim';
import { ExperimentConfig, Manager } from './manager';
import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService';

Expand Down Expand Up @@ -132,10 +132,10 @@ function generateParamFileName(hyperParameters: HyperParameters): string {
* Must be paired with `cleanupUnitTest()`.
*/
function prepareUnitTest(): void {
Container.snapshot(Database);
Container.snapshot(DataStore);
Container.snapshot(TrainingService);
Container.snapshot(Manager);
IocShim.snapshot(Database);
IocShim.snapshot(DataStore);
IocShim.snapshot(TrainingService);
IocShim.snapshot(Manager);

resetGlobals();

Expand All @@ -152,10 +152,10 @@ function prepareUnitTest(): void {
* Must be paired with `prepareUnitTest()`.
*/
function cleanupUnitTest(): void {
Container.restore(Manager);
Container.restore(TrainingService);
Container.restore(DataStore);
Container.restore(Database);
IocShim.restore(Manager);
IocShim.restore(TrainingService);
IocShim.restore(DataStore);
IocShim.restore(Database);
}

let cachedIpv4Address: string | null = null;
Expand Down
4 changes: 2 additions & 2 deletions ts/nni_manager/core/nniDataStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import assert from 'assert';
import { Deferred } from 'ts-deferred';

import * as component from '../common/component';
import { IocShim } from 'common/ioc_shim';
import { Database, DataStore, MetricData, MetricDataRecord, MetricType,
TrialJobEvent, TrialJobEventRecord, TrialJobInfo, HyperParameterFormat,
ExportedDataFormat } from '../common/datastore';
Expand All @@ -16,7 +16,7 @@ import { TrialJobDetail, TrialJobStatus } from '../common/trainingService';
import { getDefaultDatabaseDir, mkDirP } from '../common/utils';

class NNIDataStore implements DataStore {
private db: Database = component.get(Database);
private db: Database = IocShim.get(Database);
private log: Logger = getLogger('NNIDataStore');
private initTask!: Deferred<void>;

Expand Down
14 changes: 3 additions & 11 deletions ts/nni_manager/core/nnimanager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import assert from 'assert';
import { ChildProcess, StdioOptions } from 'child_process';
import { Deferred } from 'ts-deferred';
import * as component from '../common/component';
import { IocShim } from 'common/ioc_shim';
import { DataStore, MetricDataRecord, MetricType, TrialJobInfo } from '../common/datastore';
import { NNIError } from '../common/errors';
import { getExperimentId } from '../common/experimentStartupInfo';
Expand Down Expand Up @@ -64,7 +64,7 @@ class NNIManager implements Manager {
this.readonly = false;

this.log = getLogger('NNIManager');
this.dataStore = component.get(DataStore);
this.dataStore = IocShim.get(DataStore);
this.status = {
status: 'INITIALIZED',
errors: []
Expand Down Expand Up @@ -315,11 +315,6 @@ class NNIManager implements Manager {
this.trainingService = new fcModule.FrameworkControllerTrainingService();
break;
}
case 'adl_config': {
const adlModule = await import('../training_service/kubernetes/adl/adlTrainingService');
this.trainingService = new adlModule.AdlTrainingService();
break;
}
default:
throw new Error("Setup training service failed.");
}
Expand Down Expand Up @@ -395,7 +390,7 @@ class NNIManager implements Manager {
this.setStatus('STOPPED');
this.log.info('Experiment stopped.');

await component.get<TensorboardManager>(TensorboardManager).stop();
await IocShim.get<TensorboardManager>(TensorboardManager).stop();
await this.dataStore.close();
}

Expand Down Expand Up @@ -492,9 +487,6 @@ class NNIManager implements Manager {
} else if (platform === 'frameworkcontroller') {
const module_ = await import('../training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService');
return new module_.FrameworkControllerTrainingService();
} else if (platform === 'adl') {
const module_ = await import('../training_service/kubernetes/adl/adlTrainingService');
return new module_.AdlTrainingService();
} else {
this.pollInterval = 0.5;
const module_ = await import('../training_service/v3/compat');
Expand Down
6 changes: 3 additions & 3 deletions ts/nni_manager/extensions/nniTensorboardManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ import cp from 'child_process';
import path from 'path';
import { ChildProcess } from 'child_process';

import * as component from '../common/component';
import { getLogger, Logger } from '../common/log';
import { getTunerProc, isAlive, uniqueString, mkDirPSync, getFreePort } from '../common/utils';
import { Manager } from '../common/manager';
import { TensorboardParams, TensorboardTaskStatus, TensorboardTaskInfo, TensorboardManager } from '../common/tensorboardManager';
import globals from 'common/globals';
import { globals } from 'common/globals';
import { IocShim } from 'common/ioc_shim';

class TensorboardTaskDetail implements TensorboardTaskInfo {
public id: string;
Expand Down Expand Up @@ -39,7 +39,7 @@ class NNITensorboardManager implements TensorboardManager {
this.log = getLogger('NNITensorboardManager');
this.tensorboardTaskMap = new Map<string, TensorboardTaskDetail>();
this.setTensorboardVersion();
this.nniManager = component.get(Manager);
this.nniManager = IocShim.get(Manager);
}

public async startTensorboardTask(tensorboardParams: TensorboardParams): Promise<TensorboardTaskDetail> {
Expand Down
14 changes: 6 additions & 8 deletions ts/nni_manager/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@

import 'app-module-path/register'; // so we can use absolute path to import

import { Container, Scope } from 'typescript-ioc';

import { globals, initGlobals } from 'common/globals';
initGlobals();

import * as component from 'common/component';
import { Database, DataStore } from 'common/datastore';
import { IocShim } from 'common/ioc_shim';
import { Logger, getLogger } from 'common/log';
import { Manager } from 'common/manager';
import { TensorboardManager } from 'common/tensorboardManager';
Expand All @@ -47,12 +45,12 @@ async function start(): Promise<void> {
const restServer = new RestServer(globals.args.port, globals.args.urlPrefix);
await restServer.start();

Container.bind(Manager).to(NNIManager).scope(Scope.Singleton);
Container.bind(Database).to(SqlDB).scope(Scope.Singleton);
Container.bind(DataStore).to(NNIDataStore).scope(Scope.Singleton);
Container.bind(TensorboardManager).to(NNITensorboardManager).scope(Scope.Singleton);
IocShim.bind(Database, SqlDB);
IocShim.bind(DataStore, NNIDataStore);
IocShim.bind(Manager, NNIManager);
IocShim.bind(TensorboardManager, NNITensorboardManager);

const ds: DataStore = component.get(DataStore);
const ds: DataStore = IocShim.get(DataStore);
await ds.init();

globals.rest.registerExpressRouter('/api/v1/nni', createRestHandler());
Expand Down
Loading