
Improve and rework GPT-tfjs #654

@JulienVig

Here is a list of potential improvements for gpt-tfjs in Disco:

  • Create a compile method that initializes the optimizer (rather than initializing it when fitDataset is called). This ensures the optimizer state persists across multiple calls to fitDataset.
  • Implement save and load methods to save and re-use a trained model
  • Rename classes for better clarity and consistency; e.g., multiple classes and functions are currently named GPT
  • Assess whether we can use tf.CustomCallbackArgs rather than redefining an interface for TrainingCallbacks
  • Assess whether we can use TFJS' native fitDataset method rather than overriding it with a custom training loop
    -> TF.js implements Adam but not AdamW, which GPT-2 uses; the custom optimizer adds the weight decay used in the original GPT-2.
  • TF.js only supports reading a text file line by line, which is not ideal for LLM inputs; try implementing a file reader that reads chunk by chunk instead of line by line
  • Training with gpt2 produces a NaN loss after the first epoch step

#656 and #657 should be addressed first
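
A minimal sketch of the compile idea, kept framework-free so it runs without TF.js: the optimizer is created once in a hypothetical `compile` method, so its internal state (here a momentum buffer) carries over between `fitDataset` calls. `MomentumSGD`, `TrainableModel`, and the toy quadratic objective are illustrative assumptions, not Disco or gpt-tfjs code.

```typescript
// Hypothetical sketch: names only mirror the issue's wording,
// this is not the actual Disco / gpt-tfjs API.

// A stateful optimizer: momentum SGD keeps one velocity value per parameter.
class MomentumSGD {
  private readonly lr: number;
  private readonly momentum: number;
  private velocity: number[] | null = null; // lazily sized on first update
  stepCount = 0;

  constructor(lr: number, momentum: number) {
    this.lr = lr;
    this.momentum = momentum;
  }

  applyGradients(params: number[], grads: number[]): void {
    if (this.velocity === null) this.velocity = new Array<number>(params.length).fill(0);
    const v = this.velocity;
    for (let i = 0; i < params.length; i++) {
      v[i] = this.momentum * v[i] + grads[i];
      params[i] -= this.lr * v[i];
    }
    this.stepCount += 1;
  }
}

class TrainableModel {
  params: number[];
  private optimizer: MomentumSGD | null = null;

  constructor(params: number[]) {
    this.params = params;
  }

  // Create/attach the optimizer once; its state then survives across fitDataset calls.
  compile(optimizer: MomentumSGD): void {
    this.optimizer = optimizer;
  }

  // Toy objective: minimize sum(p^2) / 2, whose gradient is p itself.
  fitDataset(steps: number): void {
    const opt = this.optimizer;
    if (opt === null) throw new Error("call compile() before fitDataset()");
    for (let s = 0; s < steps; s++) {
      const grads = this.params.slice();
      opt.applyGradients(this.params, grads);
    }
  }
}
```

With the optimizer compiled once, calls of 2 and then 3 steps end with exactly the same parameters as a single 5-step call; re-creating the optimizer between calls resets the momentum buffer and changes the result, which is the behavior the compile method is meant to avoid.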
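
For the AdamW point, a minimal framework-free sketch of one AdamW update on a flat array of parameters, assuming the usual bias-corrected Adam moments. AdamW applies the weight decay directly to the parameters, θ ← θ − lr·(m̂ / (√v̂ + ε) + λ·θ), instead of adding λ·θ to the gradient. `AdamWConfig`, `initAdamWState`, and `adamwStep` are hypothetical names; a TF.js version inside the custom training loop would perform the same arithmetic on tensors.

```typescript
// Hypothetical helper names; this is not a TF.js or Disco API.
interface AdamWConfig {
  lr: number;          // learning rate
  beta1: number;       // decay rate of the first-moment estimate
  beta2: number;       // decay rate of the second-moment estimate
  epsilon: number;     // numerical stability term
  weightDecay: number; // decoupled weight decay coefficient (lambda)
}

interface AdamWState {
  step: number; // number of updates applied so far
  m: number[];  // first-moment (mean) estimate per parameter
  v: number[];  // second-moment (uncentered variance) estimate per parameter
}

function initAdamWState(size: number): AdamWState {
  return { step: 0, m: new Array<number>(size).fill(0), v: new Array<number>(size).fill(0) };
}

// Applies one AdamW update to `params` in place and returns it.
function adamwStep(params: number[], grads: number[], state: AdamWState, cfg: AdamWConfig): number[] {
  state.step += 1;
  const biasCorr1 = 1 - Math.pow(cfg.beta1, state.step);
  const biasCorr2 = 1 - Math.pow(cfg.beta2, state.step);
  for (let i = 0; i < params.length; i++) {
    const g = grads[i];
    state.m[i] = cfg.beta1 * state.m[i] + (1 - cfg.beta1) * g;
    state.v[i] = cfg.beta2 * state.v[i] + (1 - cfg.beta2) * g * g;
    const mHat = state.m[i] / biasCorr1;
    const vHat = state.v[i] / biasCorr2;
    // Decoupled weight decay: applied to the parameter itself,
    // not folded into the gradient as in L2-regularized Adam.
    params[i] -= cfg.lr * (mHat / (Math.sqrt(vHat) + cfg.epsilon) + cfg.weightDecay * params[i]);
  }
  return params;
}
```

Because the decay term bypasses the adaptive normalization, a parameter whose moment estimates are zero (e.g. a zero gradient on the first step) still shrinks by a factor of (1 − lr·λ).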
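
A minimal sketch of reading text in fixed-size chunks instead of lines, assuming the input arrives as an async iterable of string pieces of arbitrary size (in Node, `fs.createReadStream(path, { encoding: "utf8" })` produces such pieces). `readChunks` is a hypothetical helper, not an existing TF.js or Disco function, and chunk sizes are counted in UTF-16 code units, not tokens.

```typescript
// Hypothetical helper; not an existing TF.js or Disco function.
// Re-slices a stream of text pieces (arbitrary sizes, e.g. from a file
// stream) into fixed-size chunks, treating newlines as ordinary characters.
async function* readChunks(
  pieces: AsyncIterable<string>,
  chunkSize: number,
): AsyncGenerator<string> {
  if (chunkSize <= 0) throw new Error("chunkSize must be positive");
  let buffer = "";
  for await (const piece of pieces) {
    buffer += piece;
    // Emit as many full chunks as the buffer currently holds.
    while (buffer.length >= chunkSize) {
      yield buffer.slice(0, chunkSize);
      buffer = buffer.slice(chunkSize);
    }
  }
  // Emit the trailing partial chunk, if any.
  if (buffer.length > 0) yield buffer;
}
```

A chunk can span several lines, so each one can be tokenized into a full-length training example. Slicing by code units can split a surrogate pair at a chunk boundary; a production reader would need to guard against that.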

Labels: discojs (Related to Disco.js), rework (Code that needs to be improved)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions