Biencoder and polyencoder architecture #175

Merged: 13 commits, Aug 20, 2024
54 changes: 54 additions & 0 deletions configs/config.yaml
@@ -0,0 +1,54 @@
# Model Configuration
model_name: microsoft/deberta-v3-small # Hugging Face model
labels_encoder: "BAAI/bge-small-en-v1.5"
name: "span level gliner"
max_width: 12
hidden_size: 768
dropout: 0.3
fine_tune: true
subtoken_pooling: first
fuse_layers: false
post_fusion_schema: "l2l-l2t-t2t"
span_mode: markerV0

# Training Parameters
num_steps: 100000
train_batch_size: 8
eval_every: 5000
warmup_ratio: 0.05
scheduler_type: "cosine"

# loss function
loss_alpha: 0.75
loss_gamma: 0
label_smoothing: 0
loss_reduction: "sum"

# Learning Rate and weight decay Configuration
lr_encoder: 1e-5
lr_others: 3e-5
weight_decay_encoder: 0.1
weight_decay_other: 0.01

max_grad_norm: 10.0

# Directory Paths
root_dir: gliner_logs
train_data: "data.json" #"data/nuner_train.json" # see https://github.com/urchade/GLiNER/tree/main/data
val_data_dir: "none"
# "NER_datasets": val data from the paper can be obtained from "https://drive.google.com/file/d/1T-5IbocGka35I7X3CE6yKe5N_Xg2lVKT/view"

# Pretrained Model Path
# Use "none" if no pretrained model is being used
prev_path: null

save_total_limit: 3 #maximum amount of checkpoints to save

# Advanced Training Settings
size_sup: -1
max_types: 100
shuffle_types: true
random_drop: true
max_neg_type_ratio: 1
max_len: 512
freeze_token_rep: false
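
For reference, this file is what custom_train.py now loads by default through its --config argument. The script itself calls gliner.utils.load_config_as_namespace; the sketch below is only an approximation of that helper, assuming it simply maps the YAML keys onto an attribute-style namespace.

```python
# Minimal sketch: load the YAML config and expose its keys as attributes.
# The real training script uses gliner.utils.load_config_as_namespace; its exact
# return type is an assumption here, so PyYAML + SimpleNamespace stand in for it.
from types import SimpleNamespace

import yaml


def load_yaml_config(path: str) -> SimpleNamespace:
    """Read a YAML file and return an object with one attribute per top-level key."""
    with open(path, "r", encoding="utf-8") as f:
        return SimpleNamespace(**yaml.safe_load(f))


config = load_yaml_config("configs/config.yaml")
print(config.model_name, config.labels_encoder, config.num_steps)
```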
19 changes: 11 additions & 8 deletions config.yaml → configs/config_biencoder.yaml
@@ -1,17 +1,20 @@
# Model Configuration
model_name: microsoft/deberta-v3-small # Hugging Face model
model_name: microsoft/deberta-v3-small # Hugging Face model
labels_encoder: "microsoft/deberta-v3-small"
name: "span level gliner"
max_width: 12
hidden_size: 768
dropout: 0.4
fine_tune: true
subtoken_pooling: first
fuse_layers: false
post_fusion_schema: ""
span_mode: markerV0

# Training Parameters
num_steps: 30000
train_batch_size: 3
eval_every: 3000
train_batch_size: 8
eval_every: 1000
warmup_ratio: 0.1
scheduler_type: "cosine"

@@ -27,25 +30,25 @@ lr_others: 5e-5
weight_decay_encoder: 0.01
weight_decay_other: 0.01

max_grad_norm: 1.0
max_grad_norm: 10.0

# Directory Paths
root_dir: gliner_logs
train_data: "data.json" # see https://github.com/urchade/GLiNER/tree/main/data
train_data: "data.json" #"data/nuner_train.json" # see https://github.com/urchade/GLiNER/tree/main/data
val_data_dir: "none"
# "NER_datasets": val data from the paper can be obtained from "https://drive.google.com/file/d/1T-5IbocGka35I7X3CE6yKe5N_Xg2lVKT/view"

# Pretrained Model Path
# Use "none" if no pretrained model is being used
prev_path: "none"
prev_path: null

save_total_limit: 10 #maximum amount of checkpoints to save
save_total_limit: 3 #maximum amount of checkpoints to save

# Advanced Training Settings
size_sup: -1
max_types: 25
shuffle_types: true
random_drop: true
max_neg_type_ratio: 1
max_len: 384
max_len: 386
freeze_token_rep: false
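
Compared with configs/config.yaml above, this variant keeps microsoft/deberta-v3-small for both the text encoder and the labels encoder. In either file, labels_encoder is what enables the bi-encoder setup: entity-type names are embedded by their own encoder and matched against span representations, instead of being injected into the input as <<ENT>> tokens. The snippet below is only a conceptual sketch of that matching step under an assumed mean-pooling scheme; it is not the repository's implementation, and in the actual model the span side comes from the text encoder rather than from re-encoding span strings.

```python
# Conceptual bi-encoder matching: labels and spans are embedded independently,
# then compared with a dot product. Pooling and normalization are assumptions.
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

encoder_name = "microsoft/deberta-v3-small"  # the labels_encoder from this config
tokenizer = AutoTokenizer.from_pretrained(encoder_name)
encoder = AutoModel.from_pretrained(encoder_name)


def embed(texts):
    # Mean-pool the last hidden states into one vector per input text.
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    hidden = encoder(**batch).last_hidden_state           # (B, T, H)
    mask = batch["attention_mask"].unsqueeze(-1).float()  # (B, T, 1)
    return (hidden * mask).sum(1) / mask.sum(1)           # (B, H)


label_emb = F.normalize(embed(["person", "organization", "location"]), dim=-1)
span_emb = F.normalize(embed(["Barack Obama", "Microsoft", "Paris"]), dim=-1)
scores = span_emb @ label_emb.T  # (num_spans, num_labels); highest score per row wins
print(scores.argmax(dim=-1))
```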
File renamed without changes.
File renamed without changes.
66 changes: 23 additions & 43 deletions custom_train.py
@@ -27,7 +27,7 @@
from transformers import AutoTokenizer

from gliner import GLiNER, GLiNERConfig
from gliner.data_processing import GLiNERDataset, SpanProcessor, TokenProcessor
from gliner.data_processing import SpanProcessor, TokenProcessor, SpanBiEncoderProcessor, TokenBiEncoderProcessor
from gliner.data_processing.tokenizer import WordsSplitter
from gliner.data_processing.collator import DataCollatorWithPadding, DataCollator
from gliner.utils import load_config_as_namespace
@@ -79,36 +79,14 @@ def __init__(self, config, allow_distributed, compile_model=False, device='cuda'

self.device = device

self.model_config = GLiNERConfig(
model_name=config.model_name,
name=config.name,
max_width=config.max_width,
hidden_size=config.hidden_size,
dropout=config.dropout,
fine_tune=config.fine_tune,
subtoken_pooling=config.subtoken_pooling,
span_mode=config.span_mode,
loss_alpha=config.loss_alpha,
loss_gamma=config.loss_gamma,
label_smoothing=config.label_smoothing,
loss_reduction=config.loss_reduction,
max_types=config.max_types,
shuffle_types=config.shuffle_types,
random_drop=config.random_drop,
max_neg_type_ratio=config.max_neg_type_ratio,
max_len=config.max_len,
)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
self.model_config.class_token_index=len(tokenizer)
tokenizer.add_tokens([self.model_config.ent_token, self.model_config.sep_token])
self.model_config.vocab_size = len(tokenizer)

words_splitter = WordsSplitter()
self.model_config = GLiNERConfig(**vars(config))

if config.span_mode == "token_level":
self.data_processor = TokenProcessor(self.model_config, tokenizer, words_splitter, preprocess_text=True)
else:
self.data_processor = SpanProcessor(self.model_config, tokenizer, words_splitter, preprocess_text=True)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

if config.labels_encoder is None:
self.model_config.class_token_index=len(tokenizer)
tokenizer.add_tokens([self.model_config.ent_token, self.model_config.sep_token])
self.model_config.vocab_size = len(tokenizer)

self.allow_distributed = allow_distributed

@@ -190,17 +168,19 @@ def create_optimizer(self, opt_model, **optimizer_kwargs):
def setup_model_and_optimizer(self, rank=None, device=None):
if device is None:
device = self.device
if self.config.prev_path != "none":
model = GLiNER.from_pretrained(self.config.prev_path, data_processor=self.data_processor).to(device)
if self.config.prev_path is not None:
model = GLiNER.from_pretrained(self.config.prev_path).to(device)
model.config = self.model_config
else:
model = GLiNER(self.model_config, data_processor=self.data_processor).to(device)
model.resize_token_embeddings([self.model_config.ent_token, self.model_config.sep_token],
set_class_token_index = False,
add_tokens_to_tokenizer=False)
model = GLiNER(self.model_config).to(device)
if self.config.labels_encoder is None:
model.resize_token_embeddings([self.model_config.ent_token, self.model_config.sep_token],
set_class_token_index = False,
add_tokens_to_tokenizer=False)
if rank is not None:
model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
model.module.resize_token_embeddings([self.model_config.ent_token, self.model_config.sep_token],
if self.config.labels_encoder is None:
model.module.resize_token_embeddings([self.model_config.ent_token, self.model_config.sep_token],
set_class_token_index = False,
add_tokens_to_tokenizer=False)
optimizer = self.create_optimizer(model.model)
@@ -210,10 +190,10 @@ def setup_model_and_optimizer(self, rank=None, device=None):

return model, optimizer

def create_dataloader(self, dataset, sampler=None, shuffle=True):
def create_dataloader(self, dataset, data_processor, sampler=None, shuffle=True):
# dataset = GLiNERDataset(dataset, config = self.config, data_processor=self.data_processor)
# collator = DataCollatorWithPadding(self.config)
collator = DataCollator(self.config, data_processor=self.data_processor, prepare_labels=True)
collator = DataCollator(self.config, data_processor=data_processor, prepare_labels=True)
data_loader = DataLoader(dataset, batch_size=self.config.train_batch_size, num_workers=12,
shuffle=shuffle, collate_fn=collator, sampler=sampler)
return data_loader
@@ -228,7 +208,7 @@ def train_dist(self, rank, world_size, dataset):

sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)

train_loader = self.create_dataloader(dataset, sampler=sampler, shuffle=False)
train_loader = self.create_dataloader(dataset, model.data_processor, sampler=sampler, shuffle=False)

num_steps = self.config.num_steps // world_size

@@ -347,14 +327,14 @@ def run(self):
else:
model, optimizer = self.setup_model_and_optimizer()

train_loader = self.create_dataloader(data, shuffle=True)
train_loader = self.create_dataloader(data, model.data_processor, shuffle=True)

self.train(model, optimizer, train_loader, num_steps=self.config.num_steps, device=self.device)


def create_parser():
parser = argparse.ArgumentParser(description="Span-based NER")
parser.add_argument("--config", type=str, default="config.yaml", help="Path to config file")
parser.add_argument("--config", type=str, default="configs/config.yaml", help="Path to config file")
parser.add_argument('--log_dir', type=str, default='logs', help='Path to the log directory')
parser.add_argument('--allow_distributed', type=bool, default=False,
help='Whether to allow distributed training if there are more than one GPU available')
@@ -372,4 +352,4 @@ def create_parser():
trainer = Trainer(config, allow_distributed=args.allow_distributed,
compile_model = args.compile_model,
device='cuda' if torch.cuda.is_available() else 'cpu')
trainer.run()
trainer.run()
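
With this refactor the same entry point covers both architectures: for example, `python custom_train.py --config configs/config_biencoder.yaml` trains a bi-encoder model, while a config whose `labels_encoder` is unset falls back to the original single-encoder path, which is the only case where the <<ENT>>/<<SEP>> tokens are added and the embedding matrix is resized.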
2 changes: 1 addition & 1 deletion gliner/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.2.8"
__version__ = "0.2.10.dev"

from .model import GLiNER
from .config import GLiNERConfig
15 changes: 15 additions & 0 deletions gliner/config.py
@@ -7,21 +7,25 @@ class GLiNERConfig(PretrainedConfig):
is_composition = True
def __init__(self,
model_name: str = "microsoft/deberta-v3-small",
labels_encoder: str = None,
name: str = "span level gliner",
max_width: int = 12,
hidden_size: int = 512,
dropout: float = 0.4,
fine_tune: bool = True,
subtoken_pooling: str = "first",
span_mode: str = "markerV0",
post_fusion_schema: str = '', #l2l-l2t-t2t
vocab_size: int = -1,
max_neg_type_ratio: int = 1,
max_types: int = 25,
max_len: int = 384,
words_splitter_type: str = "whitespace",
has_rnn: bool = True,
fuse_layers: bool = False,
class_token_index: int = -1,
encoder_config: Optional[dict] = None,
labels_encoder_config: Optional[dict] = None,
ent_token = "<<ENT>>",
sep_token = "<<SEP>>",
**kwargs):
@@ -32,20 +36,31 @@ def __init__(self,
else "deberta-v2")
encoder_config = CONFIG_MAPPING[encoder_config["model_type"]](**encoder_config)
self.encoder_config = encoder_config

if isinstance(labels_encoder_config, dict):
labels_encoder_config["model_type"] = (labels_encoder_config["model_type"]
if "model_type" in labels_encoder_config
else "deberta-v2")
labels_encoder_config = CONFIG_MAPPING[labels_encoder_config["model_type"]](**labels_encoder_config)
self.labels_encoder_config = labels_encoder_config

self.model_name = model_name
self.labels_encoder = labels_encoder
self.name = name
self.max_width = max_width
self.hidden_size = hidden_size
self.dropout = dropout
self.fine_tune = fine_tune
self.subtoken_pooling = subtoken_pooling
self.span_mode = span_mode
self.post_fusion_schema = post_fusion_schema
self.vocab_size = vocab_size
self.max_neg_type_ratio = max_neg_type_ratio
self.max_types = max_types
self.max_len = max_len
self.words_splitter_type = words_splitter_type
self.has_rnn = has_rnn
self.fuse_layers = fuse_layers
self.class_token_index = class_token_index
self.ent_token = ent_token
self.sep_token = sep_token
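
A minimal sketch of building the extended configuration for a bi-encoder model; the values mirror configs/config.yaml, and the comment on post_fusion_schema is an interpretation rather than documented behaviour.

```python
from gliner import GLiNERConfig

# Bi-encoder setup: a text encoder (model_name) plus a separate labels encoder.
config = GLiNERConfig(
    model_name="microsoft/deberta-v3-small",
    labels_encoder="BAAI/bge-small-en-v1.5",
    max_width=12,
    hidden_size=768,
    dropout=0.3,
    span_mode="markerV0",
    fuse_layers=False,
    post_fusion_schema="l2l-l2t-t2t",  # presumably the polyencoder-style label/token fusion pattern
)
print(config.labels_encoder, config.post_fusion_schema)
```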
2 changes: 1 addition & 1 deletion gliner/data_processing/__init__.py
@@ -1,4 +1,4 @@
from .processor import SpanProcessor, TokenProcessor
from .processor import SpanProcessor, SpanBiEncoderProcessor, TokenProcessor, TokenBiEncoderProcessor
from .collator import DataCollator
from .tokenizer import WordsSplitter
from .dataset import GLiNERDataset
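
The two new processors mirror the existing span/token split; presumably the bi-encoder variants are used whenever a labels_encoder is configured. Since the selection now happens inside GLiNER itself (custom_train.py reads model.data_processor), the dispatch below is only an illustration of that rule, not the library's code.

```python
from gliner.data_processing import (SpanBiEncoderProcessor, SpanProcessor,
                                    TokenBiEncoderProcessor, TokenProcessor)


def pick_processor_class(config):
    # Illustrative dispatch: token- vs span-level follows span_mode,
    # bi-encoder vs single-encoder follows whether labels_encoder is set.
    if config.span_mode == "token_level":
        return TokenBiEncoderProcessor if config.labels_encoder else TokenProcessor
    return SpanBiEncoderProcessor if config.labels_encoder else SpanProcessor
```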
3 changes: 2 additions & 1 deletion gliner/data_processing/collator.py
@@ -38,6 +38,7 @@ def __call__(self, input_x):
model_input['id_to_classes'] = raw_batch['id_to_classes']
if self.return_entities:
model_input['entities'] = raw_batch['entities']
model_input = {k:v for k, v in model_input.items() if v is not None}
return model_input

class DataCollatorWithPadding:
@@ -91,7 +92,7 @@ def __call__(self, batch):
padded_batch[key] = torch.tensor(key_data, dtype=torch.float32).to(self.device)
else:
raise TypeError(f"Unsupported data type for key '{key}': {type(key_data[0])}")

padded_batch = {k:v for k,v in padded_batch.items() if v is not None}
return padded_batch

def _pad_2d_tensor(self, key_data):
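
Both collators now drop keys whose value is None before returning the batch, presumably so that optional fields (labels, entities, id_to_classes) can simply be absent rather than reaching the model as null entries. The pattern in isolation:

```python
# Strip unset optional fields from a prepared batch before handing it to the model.
model_input = {"input_ids": [[101, 2054, 102]], "labels": None, "id_to_classes": None}
model_input = {k: v for k, v in model_input.items() if v is not None}
print(model_input)  # {'input_ids': [[101, 2054, 102]]}
```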