
Add ResMLP as alternative to highway in CNNCharacterEmbedding
Summary: The highway connection introduces post-quantization error. As an alternative, an MLP with a residual connection (ResMLP) can be used.
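For intuition, a residual MLP connection adds each layer's output back to its input, avoiding the sigmoid gating that makes highway layers quantization-unfriendly. A minimal sketch, illustrative only; the class name and exact layer structure here are assumptions, not the code in this commit:

import torch
import torch.nn as nn

class ResMLPSketch(nn.Module):
    # Each layer's output is added back to its input (x + f(x)); there is
    # no sigmoid gate as in highway networks, so it quantizes cleanly.
    def __init__(self, dim: int, num_layers: int, dropout: float = 0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Dropout(dropout))
            for _ in range(num_layers)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = x + layer(x)  # residual connection
        return x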

Reviewed By: gardenia22

Differential Revision: D36616180

fbshipit-source-id: ef4e4ea39674e7d6cc3096996d5d621a86419651
dogNoWorry authored and facebook-github-bot committed Jun 2, 2022
1 parent 45ea778 commit 0b0489d
Showing 3 changed files with 48 additions and 8 deletions.
13 changes: 12 additions & 1 deletion pytext/config/field_config.py
@@ -63,11 +63,22 @@ class DictFeatConfig(ModuleConfig):
     use_weights: bool = True


+class ConnectionConfig(ConfigBase):
+    # Config for the connection layers in the embedding: highway or residual MLP.
+    # Default is highway; see https://arxiv.org/abs/1508.06615
+    connection_type: str = "highway"
+    # Number of layers
+    num_layers: int = 0
+    # Dropout ratio for resmlp (defaults to PyTorch's default of 0.1); ignored for highway
+    dropout: float = 0.1
+
+
 class CharFeatConfig(ModuleConfig):
     embed_dim: int = 100
     sparse: bool = False
     cnn: CNNParams = CNNParams()
-    highway_layers: int = 0
+    connection: ConnectionConfig = ConnectionConfig()
+    highway_layers: Optional[int] = None  # kept for backward compatibility
     projection_dim: Optional[int] = None
     export_input_names: List[str] = ["char_vals"]
     vocab_from_train_data: bool = True
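With the new config, opting into the residual MLP is a one-line change in the feature config. A minimal sketch, assuming ConfigBase accepts keyword construction; the values are illustrative:

from pytext.config.field_config import CharFeatConfig, ConnectionConfig

# Select the ResMLP connection; highway_layers stays None, so the
# deprecation branch in from_config (see below) is not taken.
char_feat = CharFeatConfig(
    connection=ConnectionConfig(connection_type="resmlp", num_layers=2, dropout=0.1),
)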
31 changes: 30 additions & 1 deletion pytext/models/embeddings/char_embedding.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

+import warnings
 from typing import List, Optional

 import torch
@@ -45,13 +46,37 @@ def from_config(
         if vocab_size is None:
             vocab_size = metadata.vocab_size

+        if config.highway_layers is not None:
+            warnings.warn(
+                "Specifying highway_layers is deprecated; use ConnectionConfig instead.",
+                DeprecationWarning,
+            )
+            highway_layers = config.highway_layers
+            resmlp_layers = 0
+            resmlp_dropout = 0
+        else:
+            if config.connection.connection_type == "highway":
+                highway_layers = config.connection.num_layers
+                resmlp_layers = 0
+                resmlp_dropout = 0
+            elif config.connection.connection_type == "resmlp":
+                highway_layers = 0
+                resmlp_layers = config.connection.num_layers
+                resmlp_dropout = config.connection.dropout
+            else:
+                raise NotImplementedError(
+                    "Connection type should be either 'highway' or 'resmlp'."
+                )
+
         return cls(
             vocab_size,
             config.embed_dim,
             config.cnn.kernel_num,
             config.cnn.kernel_sizes,
-            config.highway_layers,
+            highway_layers,
             config.projection_dim,
+            resmlp_layers,
+            resmlp_dropout,
         )

     def __init__(
@@ -62,6 +87,8 @@ def __init__(
         kernel_sizes: List[int],
         highway_layers: int,
         projection_dim: Optional[int],
+        resmlp_layers: int = 0,
+        resmlp_dropout: float = 0.1,  # defaults to PyTorch's default
         *args,
         **kwargs,
     ) -> None:
@@ -78,6 +105,8 @@ def __init__(
             kernel_sizes=kernel_sizes,
             highway_layers=highway_layers,
             projection_dim=projection_dim,
+            resmlp_layers=resmlp_layers,
+            resmlp_dropout=resmlp_dropout,
         )
         log_class_usage(__class__)
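The dispatch above keeps old configs working: anything that still sets highway_layers takes the deprecated path and warns, while new configs are routed by connection_type. A hedged sketch of the two paths; the from_config arguments beyond config and vocab_size are assumptions:

import warnings
from pytext.config.field_config import CharFeatConfig, ConnectionConfig
from pytext.models.embeddings.char_embedding import CharacterEmbedding

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = CharacterEmbedding.from_config(
        CharFeatConfig(highway_layers=2), vocab_size=1000
    )  # highway path, emits DeprecationWarning
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

modern = CharacterEmbedding.from_config(
    CharFeatConfig(connection=ConnectionConfig(connection_type="resmlp", num_layers=2)),
    vocab_size=1000,
)  # resmlp path, no warning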
12 changes: 6 additions & 6 deletions pytext/models/word_model.py
@@ -209,12 +209,12 @@ class ByteModelInput(Model.Config.ModelInput):
     @classmethod
     def create_embedding(cls, config, tensorizers):
         return CharacterEmbedding(
-            tensorizers["token_bytes"].NUM_BYTES,
-            config.embedding.embed_dim,
-            config.embedding.cnn.kernel_num,
-            config.embedding.cnn.kernel_sizes,
-            config.embedding.highway_layers,
-            config.embedding.projection_dim,
+            num_embeddings=tensorizers["token_bytes"].NUM_BYTES,
+            embed_dim=config.embedding.embed_dim,
+            out_channels=config.embedding.cnn.kernel_num,
+            kernel_sizes=config.embedding.cnn.kernel_sizes,
+            highway_layers=config.embedding.connection.num_layers,
+            projection_dim=config.embedding.projection_dim,
         )

     def vocab_to_export(self, tensorizers):
