Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/lint-test-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -363,5 +363,5 @@ jobs:
working-directory: ${{ env.lib_node_dir }}
- run: npm ci && npm link ../${{ env.lib_node_dir }} && npm run build
working-directory: ${{ env.server_dir }}
- run: npm ci && npm link ../${{ env.server_dir }} && npm link ../${{ env.lib_node_dir }}
- run: npm ci && npm link ../${{ env.server_dir }} ../${{ env.lib_node_dir }}
- run: npm start -- -t cifar10 -u 1 -e 1
238 changes: 238 additions & 0 deletions discojs/discojs-core/src/aggregator/base.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
import { client, Task, tf, AsyncInformant } from '..'

import { EventEmitter } from 'events'

import { Map, Set } from 'immutable'

export enum AggregationStep {
ADD,
UPDATE,
AGGREGATE
}

export abstract class Base<T> {
/**
* Contains the ids of all active nodes, i.e. members of the aggregation group at
* a given round. It is a subset of all the nodes available in the network.
*/
protected _nodes: Set<client.NodeID>
/**
* Contains the contributions received from active nodes, accessible by node id.
* It defines the effective aggregation group, which is possibly a subset
* of all active nodes, depending on the aggregation scheme.
*/
protected contributions: Map<number, Map<client.NodeID, T>>
/**
* Emits the aggregation event whenever an aggregation step is performed.
* Triggers the resolve of the result promise and the preparation for the
* next aggregation round.
*/
private readonly eventEmitter: EventEmitter

protected informant?: AsyncInformant<T>
/**
* The result promise which, on resolve, will contain the current aggregation result.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

want to say if it's a model (params) or model difference, or could be either?

Copy link
Contributor

@s314cy s314cy Jul 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it can be anything really, since the class is generic, but in the case of the subclasses, it can either be the model weights or model weights difference

currently, I'm pretty sure the weights passed to the client are the entire model, which will go back to model difference in the future polishing PR(s): write docs, support a directed communication graph, re-include DP (includes model diff & clipping), re-write byzantine-robustness, etc.

*/
protected result: Promise<T>
/**
* The current aggregation round, used for assessing whether a contribution is recent enough
* or not.
*/
protected _round = 0

protected _communicationRound = 0

constructor (
public readonly task: Task,
protected _model?: tf.LayersModel,
protected readonly roundCutoff = 0,
public readonly communicationRounds = 1
) {
this.eventEmitter = new EventEmitter()
this.contributions = Map()
this._nodes = Set()

this.result = this.makeResult()

this.eventEmitter.on('aggregation', () => {
this.nextRound()
})
}

/**
* Adds a node's contribution to the aggregator for a given round.
* The contribution will be aggregated during the round's aggregation step.
* @param nodeId The node's id
* @param contribution The node's contribution
* @param round For which round the contribution was made
*/
abstract add (nodeId: client.NodeID, contribution: T, round: number, communicationRound?: number): boolean

/**
* Performs the aggregation step over the received node contributions.
* Must store the aggregation's result in the aggregator's result promise.
*/
abstract aggregate (): void

registerObserver (informant: AsyncInformant<T>): void {
this.informant = informant
}

/**
* Returns whether the given round is recent enough, dependent on the
* aggregator's round cutoff.
* @param round The round
* @returns True if the round is recent enough, false otherwise
*/
isWithinRoundCutoff (round: number): boolean {
return this.round - round <= this.roundCutoff
}

/**
* Logs useful messages during the various aggregation steps.
* @param step The aggregation step
* @param from The node which triggered the logging message
*/
log (step: AggregationStep, from?: client.NodeID): void {
switch (step) {
case AggregationStep.ADD:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add might be maybe misinterpreted as addition potentially, but here you mean more like save or register?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it corresponds to the aggregator's add method and will have a dedicated docstring in #580 so it'll make more sense then :)

console.log(`> Adding contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`)
return
case AggregationStep.UPDATE:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update meaning?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it corresponds to a node overriding its previous contribution with a new one, and will be clearer in #580 with its dedicated docstring

if (from === undefined) {
return
}
console.log(`> Updating contribution from node ${from} for round (${this.communicationRound}, ${this.round})`)
return
case AggregationStep.AGGREGATE:
console.log('*'.repeat(80))
console.log(`Buffer is full. Aggregating weights for round (${this.communicationRound}, ${this.round})\n`)
}
}

setModel (model: tf.LayersModel): void {
this._model = model
}

/**
* Adds a node's id to the set of active nodes. A node represents an active neighbor
* peer/client within the network, whom we are communicating with during this aggregation
* round.
* @param nodeId The node to be added
*/
registerNode (nodeId: client.NodeID): boolean {
if (!this.nodes.has(nodeId)) {
this._nodes = this._nodes.add(nodeId)
return true
}
return false
}

/**
* Overwrites the current set of active nodes with the given one. A node represents
* an active neighbor peer/client within the network, whom we are communicating with
* during this aggregation round.
* @param nodeIds The new set of nodes
*/
setNodes (nodeIds: Set<client.NodeID>): void {
this._nodes = nodeIds
}

resetNodes (): void {
this._nodes = Set()
}

/**
* Sets the aggregator's round number. To be used whenever the aggregator is out of sync
* with the network's round.
* @param round The new round
*/
setRound (round: number): void {
if (round > this.round) {
this._round = round
}
}

/**
* Emits the event containing the aggregation result, which allows the result
* promise to resolve and for the next aggregation round to take place.
* @param aggregated The aggregation result
*/
protected emit (aggregated: T): void {
this.eventEmitter.emit('aggregation', aggregated)
}

/**
* Resets the aggregator's step and prepares it for the next aggregation round.
*/
public nextRound (): void {
if (++this._communicationRound === this.communicationRounds) {
this._communicationRound = 0
this._round++
this.contributions = Map()
}
this.result = this.makeResult()
this.informant?.update()
}

private async makeResult (): Promise<T> {
return await new Promise((resolve) => {
this.eventEmitter.once('aggregation', (w) => {
resolve(w)
})
})
}

/**
* The aggregation result can be awaited upon in an asynchronous fashion, to allow
* for the receipt of contributions while performing other tasks. This function
* gives access to the current aggregation result's promise, which will eventually
* resolve and contain the result of the very next aggregation step, at the
* time of the function call.
* @returns The promise containing the aggregation result
*/
async receiveResult (): Promise<T> {
return await this.result
}

/**
* Constructs the payload sent to other nodes as contribution.
* @param base Object from which the payload is computed
*/
abstract makePayloads (base: T): Map<client.NodeID, T>

abstract isFull (): boolean

/**
* The set of node ids, representing our neighbors within the network.
*/
get nodes (): Set<client.NodeID> {
return this._nodes
}

/**
* The aggregation round.
*/
get round (): number {
return this._round
}

/**
* The aggregator's current size, defined by its amount of contributions.
* The size is bounded by the amount of all active nodes.
*/
get size (): number {
return this.contributions.size
}

/**
* The aggregator's current model.
*/
get model (): tf.LayersModel | undefined {
return this._model
}

get communicationRound (): number {
return this._communicationRound
}
}
27 changes: 27 additions & 0 deletions discojs/discojs-core/src/aggregator/get.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import { aggregator, Task } from '..'

export enum AggregatorChoice {
MEAN,
ROBUST,
SECURE,
BANDIT
}

export function getAggregator (task: Task): aggregator.Aggregator {
const error = new Error('not implemented')
switch (task.trainingInformation.aggregator) {
case AggregatorChoice.MEAN:
return new aggregator.MeanAggregator(task)
case AggregatorChoice.ROBUST:
throw error
case AggregatorChoice.BANDIT:
throw error
case AggregatorChoice.SECURE:
if (task.trainingInformation.scheme !== 'decentralized') {
throw new Error('secure aggregation is currently supported for decentralized only')
}
return new aggregator.SecureAggregator(task)
default:
return new aggregator.MeanAggregator(task)
}
}
11 changes: 11 additions & 0 deletions discojs/discojs-core/src/aggregator/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import { WeightsContainer } from '../weights'
import { Base } from './base'

export { Base as AggregatorBase, AggregationStep } from './base'
export { MeanAggregator } from './mean'
export { RobustAggregator } from './robust'
export { SecureAggregator } from './secure'

export { getAggregator, AggregatorChoice } from './get'

export type Aggregator = Base<WeightsContainer>
Loading