easynlp.distillation

VanillaKD

class easynlp.distillation.distill_dataset.DistillatoryBaseDataset(user_defined_parameters: dict, *args, **kwargs)[source]

A dataset class for supporting knowledge distillation. This class does not duplicate the methods of BaseDataset; it only handles the arguments that are specific to knowledge distillation.

Parameters: user_defined_parameters -- the dict of user-defined parameters for knowledge distillation.
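For orientation, a minimal construction sketch is shown below. The keys inside user_defined_parameters and the remaining keyword arguments are illustrative placeholders modeled on BaseDataset-style text-classification datasets, not confirmed names from this class:

```python
from easynlp.distillation.distill_dataset import DistillatoryBaseDataset

# KD-specific settings; the keys shown here are illustrative placeholders,
# not necessarily the exact names EasyNLP expects.
user_defined_parameters = {
    "app_parameters": {"enable_distillation": True}
}

# The remaining arguments are assumed to follow the usual BaseDataset-style
# text-classification signature; adjust them to your data.
train_dataset = DistillatoryBaseDataset(
    user_defined_parameters,
    pretrained_model_name_or_path="bert-base-uncased",
    data_file="train.tsv",
    max_seq_length=128,
    input_schema="label:str:1,content:str:1",
    first_sequence="content",
    label_name="label",
)
```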
convert_single_row_to_example(row)[source]

Converts a single example row into a dict of values.

load_bin_file()[source]
class easynlp.distillation.distill_application.DistillatoryBaseApplication[source]

This is the application class for supporting knowledge distillation.

compute_loss(forward_outputs, label_ids, teacher_logits, **kwargs)[source]

Computes the knowledge distillation loss based on the teacher logits.

Parameters:
  • forward_outputs -- the dict of the output tensors of the student model
  • label_ids -- the true label ids
  • teacher_logits -- the tensor of teacher logits

Returns: the dict of output tensors containing the loss
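For reference, a minimal sketch of what a vanilla KD loss of this kind typically computes: a temperature-softened KL term against the teacher logits blended with the ordinary cross-entropy against the true labels. This illustrates the technique rather than reproducing EasyNLP's exact implementation; temperature and alpha are hypothetical hyperparameters:

```python
import torch
import torch.nn.functional as F

def vanilla_kd_loss(student_logits, teacher_logits, label_ids, temperature=2.0, alpha=0.5):
    # Hard-label loss: standard cross-entropy against the ground-truth labels.
    hard_loss = F.cross_entropy(student_logits, label_ids)
    # Soft-label loss: KL divergence between temperature-softened distributions,
    # scaled by T^2 to keep gradient magnitudes comparable.
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature ** 2
    # Blend the two terms; alpha controls the contribution of the soft targets.
    return alpha * soft_loss + (1.0 - alpha) * hard_loss
```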

class easynlp.core.distiller.DistillatoryTrainer(user_defined_parameters, **kwargs)[source]
train()[source]
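A rough sketch of wiring the pieces together; the exact keyword arguments the trainer accepts beyond user_defined_parameters are an assumption modeled on the other trainers in this module:

```python
from easynlp.core.distiller import DistillatoryTrainer

# student_app (a DistillatoryBaseApplication subclass), train_dataset
# (a DistillatoryBaseDataset) and evaluator are assumed to be built already.
trainer = DistillatoryTrainer(
    user_defined_parameters,
    model=student_app,
    train_dataset=train_dataset,
    evaluator=evaluator,
)
trainer.train()
```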

MetaKD

class easynlp.distillation.distill_metakd_dataset.MetakdSentiClassificationDataset(pretrained_model_name_or_path, data_file, max_seq_length, input_schema, first_sequence, label_name=None, second_sequence=None, label_enumerate_values=None, *args, **kwargs)[source]

A dataset class for supporting MetaKD knowledge distillation. This class is based on BaseDataset and takes the following additional arguments:

genre: the domain of the dataset; data from all domains is used when genre is "all".

domain_label: a list of the domains in the dataset; the domain list of the sentiment datasets is the default value.
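A minimal construction sketch based on the signature above; the file name, input schema, and domain list are placeholders (the four-domain Amazon sentiment setup is only assumed here):

```python
from easynlp.distillation.distill_metakd_dataset import MetakdSentiClassificationDataset

train_dataset = MetakdSentiClassificationDataset(
    pretrained_model_name_or_path="bert-base-uncased",
    data_file="train.tsv",                     # placeholder path
    max_seq_length=128,
    input_schema="content:str:1,label:str:1,domain:str:1,weight:str:1",  # assumed layout
    first_sequence="content",
    label_name="label",
    label_enumerate_values="positive,negative",
    genre="all",                               # use data from every domain
    domain_label=["books", "dvd", "electronics", "kitchen"],  # assumed default domains
)
```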

label_enumerate_values

Returns the enumeration of label values.

readlines_from_file(data_file, skip_first_line=None)[source]
convert_single_row_to_example(row)[source]

Converts a sample's tokens to indices. Overrides the method of the parent class as required.

Parameters:
  • row -- contains the sequences and the label
  • text_a -- the first sequence in row
  • text_b -- the second sequence in row, if self.second_sequence is set
  • label -- the label token, if self.label_name is set
  • domain -- the domain of the sequence in row
  • weight -- the weight calculated during pre-processing

Returns: a single example

encoding: an example containing the token indices. The dict additionally contains:

domain_id: the domain id that the sequence's domain is mapped to.

label_ids: the label id that the sequence's label is mapped to.

sample_weights: the same as weight.
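Illustratively, one converted example is a dict along these lines (the values are made up):

```python
example = {
    # ... token-index fields produced by the tokenizer encoding ...
    "domain_id": 2,          # the sequence's domain mapped to an integer id
    "label_ids": 1,          # the sequence's label mapped to an integer id
    "sample_weights": 0.87,  # the weight computed during pre-processing
}
```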

batch_fn(features)[source]

Divides examples into batches.

class easynlp.distillation.meta_modeling.MetaTeacherForSequenceClassification(pretrained_model_name_or_path=None, **kwargs)[source]

An application class for supporting meta-teacher learning.

Parameters:
  • pretrained_model_name_or_path -- the path of the model
  • num_labels -- the number of labels
  • num_domains -- the number of domains

Example:

```python
>>> from easynlp.distillation.meta_modeling import MetaTeacherForSequenceClassification
>>> path = "bert-base-uncased"  # using a huggingface model
>>> model = MetaTeacherForSequenceClassification(pretrained_model_name_or_path=path, num_labels=2, num_domains=4)

>>> path = "checkpoint-path"  # using a self-defined model based on Application
>>> model = MetaTeacherForSequenceClassification.from_pretrained(path)
```
init_weights()[source]
forward(inputs)[source]

input_ids, attention_mask, and token_type_ids are the same as in PreTrainedModel. domain_ids: the domain ids of the data. Shape: [batch_size]
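As an illustration, a forward call might look like the sketch below; the batch keys are taken from the description above, and `model` is the instance constructed in the earlier example:

```python
import torch

# A toy batch of two three-token sequences; domain_ids carries one domain
# id per example (shape [batch_size]), matching the description above.
inputs = {
    "input_ids": torch.tensor([[101, 2307, 102], [101, 3376, 102]]),
    "attention_mask": torch.ones(2, 3, dtype=torch.long),
    "token_type_ids": torch.zeros(2, 3, dtype=torch.long),
    "domain_ids": torch.tensor([0, 3]),
}
forward_outputs = model(inputs)
```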

compute_loss(forward_outputs, label_ids, use_domain_loss=True, use_sample_weights=True, **kwargs)[source]
class easynlp.distillation.meta_modeling.MetaStudentForSequenceClassification(pretrained_model_name_or_path=None, **kwargs)[source]

An application class for supporting meta-distillation. The arguments are the same as for MetaTeacherForSequenceClassification.

You can use the checkpoint from MetaTeacherForSequenceClassification to initialize this model. Example:

```python
>>> path = "checkpoint-path-from-MetaTeacherForSequenceClassification"
>>> model = MetaStudentForSequenceClassification.from_pretrained(path)
```

forward(inputs, is_student=False, distill_stage='all')[source]

Pre-training distillation when distill_stage is "first": returns [attentions, sequence_output, domain_content_output]. Downstream-task distillation when distill_stage is "second": returns [logits]. This split works around the problem that distributed training cannot find the tensors otherwise.
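A rough sketch of the two call patterns, assuming `student` and `inputs` are built as in the earlier examples; passing is_student=True here is an assumption:

```python
# Stage 1 ("first"): the student mimics the teacher's intermediate signals.
attentions, sequence_output, domain_content_output = student(
    inputs, is_student=True, distill_stage="first"
)

# Stage 2 ("second"): the student fits the teacher's logits on the downstream task.
(logits,) = student(inputs, is_student=True, distill_stage="second")
```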

compute_loss(**kwargs)[source]
When distill_stage is "first": the student model fits the teacher model's attentions, representations, and domain outputs.
When distill_stage is "second": the student model uses the distillation loss to fit the logits of the teacher model.
class easynlp.core.distiller.MetaTeacherTrainer(model, train_dataset, evaluator, **kwargs)[source]
train()[source]
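Phase one of MetaKD trains the meta-teacher over all domains. A minimal sketch based on the constructor above; `train_dataset` and `evaluator` are assumed to be built elsewhere (e.g. a MetakdSentiClassificationDataset with genre="all" and an EasyNLP evaluator):

```python
from easynlp.core.distiller import MetaTeacherTrainer
from easynlp.distillation.meta_modeling import MetaTeacherForSequenceClassification

teacher = MetaTeacherForSequenceClassification(
    pretrained_model_name_or_path="bert-base-uncased", num_labels=2, num_domains=4
)

# train_dataset and evaluator are assumed to exist already.
trainer = MetaTeacherTrainer(model=teacher, train_dataset=train_dataset, evaluator=evaluator)
trainer.train()
```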
class easynlp.core.distiller.MetaDistillationTrainer(student_model, teacher_model, train_dataset, evaluator, **kwargs)[source]
set_teacher_model(model)[source]
train()[source]
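Phase two distills the trained meta-teacher into a student, again sketched only from the signatures above; the checkpoint path is a placeholder:

```python
from easynlp.core.distiller import MetaDistillationTrainer
from easynlp.distillation.meta_modeling import (
    MetaStudentForSequenceClassification,
    MetaTeacherForSequenceClassification,
)

# Both models can be initialized from the meta-teacher checkpoint (see above).
teacher = MetaTeacherForSequenceClassification.from_pretrained("meta-teacher-checkpoint")
student = MetaStudentForSequenceClassification.from_pretrained("meta-teacher-checkpoint")

trainer = MetaDistillationTrainer(
    student_model=student,
    teacher_model=teacher,
    train_dataset=train_dataset,  # typically restricted to the target domain via `genre`
    evaluator=evaluator,
)
# trainer.set_teacher_model(...) can swap in a different teacher if needed.
trainer.train()
```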