easynlp.distillation¶
VanillaKD¶
- 
class easynlp.distillation.distill_dataset.DistillatoryBaseDataset(user_defined_parameters: dict, *args, **kwargs)[source]¶
A dataset class for supporting knowledge distillation. This class does not reimplement any methods of BaseDataset; it only handles the arguments that are specific to knowledge distillation.
Parameters: user_defined_parameters -- the dict of user-defined parameters for knowledge distillation.
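As a point of reference, the sketch below shows the shape of the user_defined_parameters argument. The key names and values are illustrative assumptions rather than documented API, and concrete subclasses additionally take the usual BaseDataset arguments.
```python
>>> # Illustrative only: the key names below are assumptions, not documented API.
>>> user_defined_parameters = {
...     "app_parameters": {
...         "temperature": 2.0,  # hypothetical softmax temperature for distillation
...         "alpha": 0.5,        # hypothetical mixing weight for the distillation loss
...     }
... }
>>> # The dict is passed as the first argument of the dataset constructor;
>>> # the remaining BaseDataset arguments are forwarded through *args/**kwargs.
```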
- 
class easynlp.distillation.distill_application.DistillatoryBaseApplication[source]¶
This is the application class for supporting knowledge distillation.
- 
compute_loss(forward_outputs, label_ids, teacher_logits, **kwargs)[source]¶
Computes the knowledge distillation loss based on the teacher logits (see the sketch below).
Parameters:
- forward_outputs -- the dict of output tensors of the student model
- label_ids -- the true label ids
- teacher_logits -- the tensor of teacher logits
Returns: the dict of output tensors containing the loss
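To make the computation concrete, here is a minimal sketch of a typical vanilla KD loss that blends hard-label cross-entropy with a temperature-scaled soft-label term. The temperature, the mixing weight alpha, and the exact blending used inside compute_loss are assumptions for illustration, not the library's implementation.
```python
import torch.nn.functional as F

def vanilla_kd_loss(student_logits, label_ids, teacher_logits, temperature=2.0, alpha=0.5):
    """Illustrative vanilla KD loss; compute_loss may weight and scale the terms differently."""
    # Hard-label cross-entropy against the true labels
    ce_loss = F.cross_entropy(student_logits, label_ids)
    # Soft-label KL divergence between temperature-scaled student and teacher distributions
    kd_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature ** 2
    return alpha * ce_loss + (1.0 - alpha) * kd_loss
```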
MetaKD¶
- 
class easynlp.distillation.distill_metakd_dataset.MetakdSentiClassificationDataset(pretrained_model_name_or_path, data_file, max_seq_length, input_schema, first_sequence, label_name=None, second_sequence=None, label_enumerate_values=None, *args, **kwargs)[source]¶
A dataset class for supporting MetaKD knowledge distillation. This class is based on BaseDataset.
Additional arguments:
genre: the domain of the dataset; data from all domains is used when genre is "all".
domain_label: a list of the domains in the dataset; the domain list of the sentiment datasets is the default value.
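A hedged instantiation sketch follows; the data file, input schema string, and the chosen genre value are assumptions for illustration.
```python
>>> from easynlp.distillation.distill_metakd_dataset import MetakdSentiClassificationDataset
>>> # Illustrative only: the file name, schema string, and genre value are assumptions.
>>> dataset = MetakdSentiClassificationDataset(
...     pretrained_model_name_or_path="bert-base-uncased",
...     data_file="train.tsv",
...     max_seq_length=128,
...     input_schema="content:str:1,label:str:1",
...     first_sequence="content",
...     label_name="label",
...     genre="all",  # use data from every domain; pass a single domain name to restrict it
... )
```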
- 
label_enumerate_values¶ Returns the label enumerate values.
- 
convert_single_row_to_example(row)[source]¶
Converts the sample tokens to indices. Overrides the method of the parent class as required.
Parameters:
- row -- contains the sequence and the label.
- text_a -- the first sequence in the row.
- text_b -- the second sequence in the row if self.second_sequence is set.
- label -- the label token if self.label_name is set.
- domain -- the domain of the sequence in the row.
- weight -- the weight calculated during pre-processing.
Returns: a single example (see the sketch below)
encoding: an example containing the token indices. The dict additionally contains:
domain_id: the domain id of the sequence after mapping.
label_ids: the label id of the sequence after mapping.
sample_weights: identical to the weight.
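For illustration, a converted example might look like the dict below. The encoding field names (input_ids, attention_mask) follow the underlying tokenizer and are assumptions here; only domain_id, label_ids, and sample_weights are taken from the description above.
```python
>>> # Illustrative shape of a converted example; encoding field names are assumptions.
>>> example = {
...     "input_ids": [101, 2023, 3185, 2001, 2307, 102],  # token indices of the encoded sequence
...     "attention_mask": [1, 1, 1, 1, 1, 1],
...     "domain_id": 2,          # domain of the sequence after mapping
...     "label_ids": 1,          # label of the sequence after mapping
...     "sample_weights": 0.87,  # weight computed during pre-processing
... }
```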
- 
class easynlp.distillation.meta_modeling.MetaTeacherForSequenceClassification(pretrained_model_name_or_path=None, **kwargs)[source]¶
An application class for supporting meta-teacher learning.
Parameters:
- pretrained_model_name_or_path -- the path of the model.
- num_labels -- the number of labels.
- num_domains -- the number of domains.
Example:
```python
>>> from easynlp.distillation.meta_modeling import MetaTeacherForSequenceClassification
>>> path = "bert-base-uncased"  # using a huggingface model
>>> model = MetaTeacherForSequenceClassification(pretrained_model_name_or_path=path, num_labels=2, num_domains=4)
>>> path = "checkpoint-path"  # using a self-defined model based on Application
>>> model = MetaTeacherForSequenceClassification.from_pretrained(path)
```
- 
class easynlp.distillation.meta_modeling.MetaStudentForSequenceClassification(pretrained_model_name_or_path=None, **kwargs)[source]¶
An application class for supporting meta-distillation. The arguments are the same as those of MetaTeacherForSequenceClassification.
You can use a checkpoint from MetaTeacherForSequenceClassification to initialize this model.
Example:
```python
>>> path = "checkpoint-path-from-MetaTeacherForSequenceClassification"
>>> model = MetaStudentForSequenceClassification.from_pretrained(path)
```
- 
forward(inputs, is_student=False, distill_stage='all')[source]¶
When distill_stage is "first" (pre-training distillation), returns [attentions, sequence_output, domain_content_output]. When distill_stage is "second" (downstream-task distillation), returns [logits]. Splitting the outputs by stage works around the problem that distributed training cannot find the tensor otherwise.
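A hedged usage sketch of the two stages is shown below; the contents of inputs and the use of is_student=True are assumptions based on the signature rather than documented behaviour.
```python
>>> # Illustrative only: `inputs` is assumed to be the usual dict of batched tensors
>>> # (e.g. input_ids, attention_mask) produced by the dataset.
>>> first_stage_outputs = model.forward(inputs, is_student=True, distill_stage="first")
>>> # -> [attentions, sequence_output, domain_content_output] for pre-training distillation
>>> second_stage_outputs = model.forward(inputs, is_student=True, distill_stage="second")
>>> # -> [logits] for downstream-task distillation
```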