easynlp.distillation¶
VanillaKD¶
- class easynlp.distillation.distill_dataset.DistillatoryBaseDataset(user_defined_parameters: dict, *args, **kwargs)[source]¶
A dataset class for supporting knowledge distillation. This class does not contain the methods of BaseDataset and only handles arguments that are proprietary to knowledge distillation.
Parameters: user_defined_parameters -- the dict of user-defined parameters for knowledge distillation.
- class easynlp.distillation.distill_application.DistillatoryBaseApplication[source]¶
The application class for supporting knowledge distillation.
- compute_loss(forward_outputs, label_ids, teacher_logits, **kwargs)[source]¶
Computes the knowledge distillation loss based on the teacher logits.
Parameters: - forward_outputs -- the dict of output tensors of the student model
- label_ids -- the true label ids
- teacher_logits -- the tensor of teacher logits
Returns: the dict of output tensors containing the loss
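The exact form of the loss is defined by the library; a common formulation in vanilla KD is a temperature-scaled soft cross-entropy between student and teacher logits, mixed with the hard-label loss. A minimal sketch (the name kd_loss, the temperature T, and the mixing weight alpha are illustrative assumptions, not EasyNLP's actual signature or defaults):

```python
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, label_ids, T=2.0, alpha=0.5):
    """Mix the hard-label cross-entropy with a temperature-scaled
    KL divergence between the student and teacher distributions."""
    # Soft targets: KL(teacher || student) at temperature T; the T^2
    # factor keeps gradient magnitudes comparable across temperatures.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    # Hard targets: standard cross-entropy against the true labels.
    hard = F.cross_entropy(student_logits, label_ids)
    return alpha * soft + (1.0 - alpha) * hard
```

When the teacher and student logits agree exactly, the soft term vanishes and only the alpha-weighted hard-label loss remains.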
MetaKD¶
- class easynlp.distillation.distill_metakd_dataset.MetakdSentiClassificationDataset(pretrained_model_name_or_path, data_file, max_seq_length, input_schema, first_sequence, label_name=None, second_sequence=None, label_enumerate_values=None, *args, **kwargs)[source]¶
A dataset class for supporting MetaKD knowledge distillation. This class is based on BaseDataset.
Additional args:
genre: the domain of the dataset; data from all domains is used when genre is "all".
domain_label: the list of domains in the dataset; the domain list of the sentiment datasets is the default value.
- label_enumerate_values¶
Returns the label enumerate values.
- convert_single_row_to_example(row)[source]¶
Converts sample tokens to indices. Overrides the method of the parent class as required.
:param row: contains the sequence(s) and label.
:param text_a: the first sequence in row.
:param text_b: the second sequence in row, if self.second_sequence is set.
:param label: the label token, if self.label_name is set.
:param domain: the domain of the sequence in row.
:param weight: the weight calculated during pre-processing.
- Returns: a single example
encoding: an example containing token indices. The dict additionally contains:
domain_id: the domain id of the sequence after mapping.
label_ids: the label id of the sequence after mapping.
sample_weights: the same as the weight.
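The returned example described above can be pictured as a plain dict; a hedged sketch of its shape (the field names follow the docstring, but the values and the input_ids key are illustrative placeholders, not real outputs):

```python
# Illustrative shape of one converted example; the token indices,
# mapped ids, and weight below are made-up placeholder values.
example = {
    "input_ids": [101, 2023, 3185, 2001, 2307, 102],  # token indices (encoding)
    "label_ids": 1,          # label id of the sequence after mapping
    "domain_id": 2,          # domain id of the sequence after mapping
    "sample_weights": 0.87,  # per-sample weight from pre-processing
}
```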
- class easynlp.distillation.meta_modeling.MetaTeacherForSequenceClassification(pretrained_model_name_or_path=None, **kwargs)[source]¶
An application class for supporting meta-teacher learning.
:param pretrained_model_name_or_path: the path of the model.
:param num_labels: the number of labels.
:param num_domains: the number of domains.
Example:
```python
>>> from easynlp.distillation.meta_modeling import MetaTeacherForSequenceClassification
>>> path = "bert-base-uncased"  # using a huggingface model
>>> model = MetaTeacherForSequenceClassification(pretrained_model_name_or_path=path, num_labels=2, num_domains=4)
>>> path = "checkpoint-path"  # using a self-defined model based on Application
>>> model = MetaTeacherForSequenceClassification.from_pretrained(path)
```
- class easynlp.distillation.meta_modeling.MetaStudentForSequenceClassification(pretrained_model_name_or_path=None, **kwargs)[source]¶
An application class for supporting meta-distillation. The arguments are the same as for MetaTeacherForSequenceClassification.
You can use a checkpoint from MetaTeacherForSequenceClassification to initialize this model. Example:
```python
>>> path = "checkpoint-path-from-MetaTeacherForSequenceClassification"
>>> model = MetaStudentForSequenceClassification.from_pretrained(path)
```
- forward(inputs, is_student=False, distill_stage='all')[source]¶
When distill_stage is "first" (pre-training distillation), returns [attentions, sequence_output, domain_content_output]. When distill_stage is "second" (downstream-task distillation), returns [logits]. Returning only the tensors used at each stage avoids the problem that distributed training cannot find unused tensors.
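The stage-dependent return value can be pictured with a toy stand-in (meta_forward, backbone_outputs, and the placeholder values are illustrative assumptions, not the library's API): each stage returns only the tensors that participate in that stage's loss, so distributed data parallel never has to track unused outputs.

```python
def meta_forward(backbone_outputs, is_student=False, distill_stage="all"):
    """Toy stand-in for the forward dispatch: select only the outputs
    that the current distillation stage's loss will consume."""
    if distill_stage == "first":
        # Pre-training distillation matches intermediate features.
        keys = ["attentions", "sequence_output", "domain_content_output"]
    elif distill_stage == "second":
        # Downstream-task distillation matches the classification logits.
        keys = ["logits"]
    else:
        keys = ["attentions", "sequence_output", "domain_content_output", "logits"]
    return [backbone_outputs[k] for k in keys]
```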