Skip to content

Commit

Permalink
Remove config from utils
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhangcs committed May 11, 2023
1 parent 2f3c559 commit 32c0f20
Show file tree
Hide file tree
Showing 28 changed files with 92 additions and 70 deletions.
1 change: 1 addition & 0 deletions supar/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def compatible():
import sys
supar = sys.modules[__name__]
if supar.__version__ < '1.2':
sys.modules['supar.utils.config'] = supar.config
sys.modules['supar.utils.transform'].CoNLL = supar.models.dep.biaffine.transform.CoNLL
sys.modules['supar.utils.transform'].Tree = supar.models.const.crf.transform.Tree
sys.modules['supar.parsers'] = supar.models
Expand Down
2 changes: 1 addition & 1 deletion supar/cmds/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from supar.utils import Config
from supar.config import Config
from supar.utils.logging import init_logger, logger
from supar.utils.parallel import get_device_count, get_free_port

Expand Down
5 changes: 3 additions & 2 deletions supar/utils/config.py → supar/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from __future__ import annotations

import argparse
import yaml
import os
from ast import literal_eval
from configparser import ConfigParser
from typing import Any, Dict, Optional, Sequence

import supar
import yaml
from omegaconf import OmegaConf

import supar
from supar.utils.fn import download


Expand Down
2 changes: 1 addition & 1 deletion supar/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from supar.config import Config
from supar.modules import (CharLSTM, ELMoEmbedding, IndependentDropout,
SharedDropout, TransformerEmbedding,
TransformerWordEmbedding, VariationalLSTM)
from supar.modules.transformer import (TransformerEncoder,
TransformerEncoderLayer)
from supar.utils import Config


class Model(nn.Module):
Expand Down
14 changes: 10 additions & 4 deletions supar/models/const/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
from .tt import TetraTaggingConstituencyModel, TetraTaggingConstituencyParser
from .vi import VIConstituencyModel, VIConstituencyParser

__all__ = ['AttachJuxtaposeConstituencyModel', 'AttachJuxtaposeConstituencyParser',
'CRFConstituencyModel', 'CRFConstituencyParser',
'TetraTaggingConstituencyModel', 'TetraTaggingConstituencyParser',
'VIConstituencyModel', 'VIConstituencyParser']
__all__ = [
'AttachJuxtaposeConstituencyModel',
'AttachJuxtaposeConstituencyParser',
'CRFConstituencyModel',
'CRFConstituencyParser',
'TetraTaggingConstituencyModel',
'TetraTaggingConstituencyParser',
'VIConstituencyModel',
'VIConstituencyParser'
]
2 changes: 1 addition & 1 deletion supar/models/const/aj/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

import torch
import torch.nn as nn
from supar.config import Config
from supar.model import Model
from supar.models.const.aj.transform import AttachJuxtaposeTree
from supar.modules import GraphConvolutionalNetwork
from supar.utils import Config
from supar.utils.common import INF
from supar.utils.fn import pad

Expand Down
4 changes: 2 additions & 2 deletions supar/models/const/aj/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from typing import Dict, Iterable, Set, Union

import torch

from supar.config import Config
from supar.models.const.aj.model import AttachJuxtaposeConstituencyModel
from supar.models.const.aj.transform import AttachJuxtaposeTree
from supar.parser import Parser
from supar.utils import Config, Dataset, Embedding
from supar.utils import Dataset, Embedding
from supar.utils.common import BOS, EOS, NUL, PAD, UNK
from supar.utils.field import Field, RawField, SubwordField
from supar.utils.logging import get_logger
Expand Down
2 changes: 1 addition & 1 deletion supar/models/const/crf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import torch
import torch.nn as nn
from supar.config import Config
from supar.model import Model
from supar.modules import MLP, Biaffine
from supar.structs import ConstituencyCRF
from supar.utils import Config


class CRFConstituencyModel(Model):
Expand Down
4 changes: 2 additions & 2 deletions supar/models/const/crf/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from typing import Dict, Iterable, Set, Union

import torch

from supar.config import Config
from supar.models.const.crf.model import CRFConstituencyModel
from supar.models.const.crf.transform import Tree
from supar.parser import Parser
from supar.structs import ConstituencyCRF
from supar.utils import Config, Dataset, Embedding
from supar.utils import Dataset, Embedding
from supar.utils.common import BOS, EOS, PAD, UNK
from supar.utils.field import ChartField, Field, RawField, SubwordField
from supar.utils.logging import get_logger
Expand Down
3 changes: 1 addition & 2 deletions supar/models/const/tt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@

import torch
import torch.nn as nn

from supar.config import Config
from supar.model import Model
from supar.utils import Config
from supar.utils.common import INF


Expand Down
4 changes: 2 additions & 2 deletions supar/models/const/tt/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from typing import Dict, Iterable, Set, Union

import torch

from supar.config import Config
from supar.models.const.tt.model import TetraTaggingConstituencyModel
from supar.models.const.tt.transform import TetraTaggingTree
from supar.parser import Parser
from supar.utils import Config, Dataset, Embedding
from supar.utils import Dataset, Embedding
from supar.utils.common import BOS, EOS, PAD, UNK
from supar.utils.field import Field, RawField, SubwordField
from supar.utils.logging import get_logger
Expand Down
2 changes: 1 addition & 1 deletion supar/models/const/vi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import torch
import torch.nn as nn
from supar.config import Config
from supar.models.const.crf.model import CRFConstituencyModel
from supar.modules import MLP, Biaffine, Triaffine
from supar.structs import ConstituencyCRF, ConstituencyLBP, ConstituencyMFVI
from supar.utils import Config


class VIConstituencyModel(CRFConstituencyModel):
Expand Down
3 changes: 1 addition & 2 deletions supar/models/const/vi/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
from typing import Dict, Iterable, Set, Union

import torch

from supar.config import Config
from supar.models.const.crf.parser import CRFConstituencyParser
from supar.models.const.crf.transform import Tree
from supar.models.const.vi.model import VIConstituencyModel
from supar.utils import Config
from supar.utils.logging import get_logger
from supar.utils.metric import SpanMetric
from supar.utils.transform import Batch
Expand Down
2 changes: 1 addition & 1 deletion supar/models/dep/biaffine/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import torch
import torch.nn as nn
from supar.config import Config
from supar.model import Model
from supar.models.dep.biaffine.transform import CoNLL
from supar.modules import MLP, Biaffine
from supar.structs import DependencyCRF, MatrixTree
from supar.utils import Config
from supar.utils.common import MIN


Expand Down
4 changes: 2 additions & 2 deletions supar/models/dep/biaffine/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from typing import Iterable, Union

import torch

from supar.config import Config
from supar.models.dep.biaffine.model import BiaffineDependencyModel
from supar.models.dep.biaffine.transform import CoNLL
from supar.parser import Parser
from supar.utils import Config, Dataset, Embedding
from supar.utils import Dataset, Embedding
from supar.utils.common import BOS, PAD, UNK
from supar.utils.field import Field, RawField, SubwordField
from supar.utils.fn import ispunct
Expand Down
3 changes: 1 addition & 2 deletions supar/models/dep/crf/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
from typing import Iterable, Union

import torch

from supar.config import Config
from supar.models.dep.biaffine.parser import BiaffineDependencyParser
from supar.models.dep.crf.model import CRFDependencyModel
from supar.structs import DependencyCRF, MatrixTree
from supar.utils import Config
from supar.utils.fn import ispunct
from supar.utils.logging import get_logger
from supar.utils.metric import AttachmentMetric
Expand Down
2 changes: 1 addition & 1 deletion supar/models/dep/crf2o/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import torch
import torch.nn as nn
from supar.config import Config
from supar.models.dep.biaffine.model import BiaffineDependencyModel
from supar.models.dep.biaffine.transform import CoNLL
from supar.modules import MLP, Biaffine, Triaffine
from supar.structs import Dependency2oCRF, MatrixTree
from supar.utils import Config
from supar.utils.common import MIN


Expand Down
4 changes: 2 additions & 2 deletions supar/models/dep/crf2o/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from typing import Iterable, Union

import torch

from supar.config import Config
from supar.models.dep.biaffine.parser import BiaffineDependencyParser
from supar.models.dep.biaffine.transform import CoNLL
from supar.models.dep.crf2o.model import CRF2oDependencyModel
from supar.structs import Dependency2oCRF
from supar.utils import Config, Dataset, Embedding
from supar.utils import Dataset, Embedding
from supar.utils.common import BOS, PAD, UNK
from supar.utils.field import ChartField, Field, RawField, SubwordField
from supar.utils.fn import ispunct
Expand Down
2 changes: 1 addition & 1 deletion supar/models/dep/vi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import torch
import torch.nn as nn
from supar.config import Config
from supar.models.dep.biaffine.model import BiaffineDependencyModel
from supar.models.dep.biaffine.transform import CoNLL
from supar.modules import MLP, Biaffine, Triaffine
from supar.structs import (DependencyCRF, DependencyLBP, DependencyMFVI,
MatrixTree)
from supar.utils import Config
from supar.utils.common import MIN


Expand Down
3 changes: 1 addition & 2 deletions supar/models/dep/vi/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from typing import Iterable, Union

import torch

from supar.config import Config
from supar.models.dep.biaffine.parser import BiaffineDependencyParser
from supar.models.dep.vi.model import VIDependencyModel
from supar.utils import Config
from supar.utils.fn import ispunct
from supar.utils.logging import get_logger
from supar.utils.metric import AttachmentMetric
Expand Down
2 changes: 1 addition & 1 deletion supar/models/sdp/biaffine/model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# -*- coding: utf-8 -*-

import torch.nn as nn
from supar.config import Config
from supar.model import Model
from supar.modules import MLP, Biaffine
from supar.utils import Config


class BiaffineSemanticDependencyModel(Model):
Expand Down
4 changes: 2 additions & 2 deletions supar/models/sdp/biaffine/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from typing import Iterable, Union

import torch

from supar.config import Config
from supar.models.dep.biaffine.transform import CoNLL
from supar.models.sdp.biaffine import BiaffineSemanticDependencyModel
from supar.parser import Parser
from supar.utils import Config, Dataset, Embedding
from supar.utils import Dataset, Embedding
from supar.utils.common import BOS, PAD, UNK
from supar.utils.field import ChartField, Field, RawField, SubwordField
from supar.utils.logging import get_logger
Expand Down
2 changes: 1 addition & 1 deletion supar/models/sdp/vi/model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# -*- coding: utf-8 -*-

import torch.nn as nn
from supar.config import Config
from supar.model import Model
from supar.modules import MLP, Biaffine, Triaffine
from supar.structs import SemanticDependencyLBP, SemanticDependencyMFVI
from supar.utils import Config


class BiaffineSemanticDependencyModel(Model):
Expand Down
3 changes: 1 addition & 2 deletions supar/models/sdp/vi/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
from typing import Iterable, Union

import torch

from supar.config import Config
from supar.models.dep.biaffine.transform import CoNLL
from supar.models.sdp.biaffine.parser import BiaffineSemanticDependencyParser
from supar.models.sdp.vi.model import VISemanticDependencyModel
from supar.utils import Config
from supar.utils.logging import get_logger
from supar.utils.metric import ChartMetric
from supar.utils.transform import Batch
Expand Down
24 changes: 16 additions & 8 deletions supar/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,19 @@
from .transformer import (TransformerDecoder, TransformerEncoder,
TransformerWordEmbedding)

__all__ = ['Biaffine', 'Triaffine',
'IndependentDropout', 'SharedDropout', 'TokenDropout',
'GraphConvolutionalNetwork',
'CharLSTM', 'VariationalLSTM',
'MLP',
'ELMoEmbedding', 'TransformerEmbedding',
'TransformerWordEmbedding',
'TransformerDecoder', 'TransformerEncoder']
__all__ = [
'Biaffine',
'Triaffine',
'IndependentDropout',
'SharedDropout',
'TokenDropout',
'GraphConvolutionalNetwork',
'CharLSTM',
'VariationalLSTM',
'MLP',
'ELMoEmbedding',
'TransformerEmbedding',
'TransformerWordEmbedding',
'TransformerDecoder',
'TransformerEncoder'
]
7 changes: 5 additions & 2 deletions supar/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from torch.optim.lr_scheduler import ExponentialLR, _LRScheduler

import supar
from supar.utils import Config, Dataset
from supar.config import Config
from supar.utils import Dataset
from supar.utils.field import Field
from supar.utils.fn import download, get_rng_state, set_rng_state
from supar.utils.logging import get_logger, init_logger, progress_bar
Expand Down Expand Up @@ -172,10 +173,12 @@ def train(
find_unused_parameters=args.get('find_unused_parameters', True),
static_graph=args.get('static_graph', False))
if args.amp:
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import \
fp16_compress_hook
self.model.register_comm_hook(dist.group.WORLD, fp16_compress_hook)
if args.wandb and is_master():
import wandb

# start a new wandb run to track this script
wandb.init(config=args.primitive_config,
project=args.get('project', self.NAME),
Expand Down
30 changes: 16 additions & 14 deletions supar/structs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,19 @@
from .vi import (ConstituencyLBP, ConstituencyMFVI, DependencyLBP,
DependencyMFVI, SemanticDependencyLBP, SemanticDependencyMFVI)

__all__ = ['StructuredDistribution',
'LinearChainCRF',
'SemiMarkovCRF',
'MatrixTree',
'DependencyCRF',
'Dependency2oCRF',
'ConstituencyCRF',
'BiLexicalizedConstituencyCRF',
'DependencyMFVI',
'DependencyLBP',
'ConstituencyMFVI',
'ConstituencyLBP',
'SemanticDependencyMFVI',
'SemanticDependencyLBP', ]
__all__ = [
'StructuredDistribution',
'LinearChainCRF',
'SemiMarkovCRF',
'MatrixTree',
'DependencyCRF',
'Dependency2oCRF',
'ConstituencyCRF',
'BiLexicalizedConstituencyCRF',
'DependencyMFVI',
'DependencyLBP',
'ConstituencyMFVI',
'ConstituencyLBP',
'SemanticDependencyMFVI',
'SemanticDependencyLBP'
]
Loading

0 comments on commit 32c0f20

Please sign in to comment.