jpcorb20 commited on
Commit
a94ac47
·
verified ·
1 Parent(s): 2e0d6f7

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +10 -1
README.md CHANGED
@@ -88,7 +88,7 @@ Researchers should apply responsible AI best practices, including mapping, measu
88
  ### Loading the model locally
89
 
90
  import torch
91
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
92
 
93
  torch.random.manual_seed(0)
94
 
@@ -114,11 +114,20 @@ Researchers should apply responsible AI best practices, including mapping, measu
114
  tokenizer=tokenizer,
115
  )
116
 
 
 
 
 
 
 
 
 
117
  generation_args = {
118
  "max_new_tokens": 500,
119
  "return_full_text": False,
120
  "temperature": 0.0,
121
  "do_sample": False,
 
122
  }
123
 
124
  output = pipe(messages, **generation_args)
 
88
  ### Loading the model locally
89
 
90
  import torch
91
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria
92
 
93
  torch.random.manual_seed(0)
94
 
 
114
  tokenizer=tokenizer,
115
  )
116
 
117
+ class EosListStoppingCriteria(StoppingCriteria):
118
+ def __init__(self, eos_sequence = [32007]):
119
+ self.eos_sequence = eos_sequence
120
+
121
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
122
+ last_ids = input_ids[:,-len(self.eos_sequence):].tolist()
123
+ return self.eos_sequence in last_ids
124
+
125
  generation_args = {
126
  "max_new_tokens": 500,
127
  "return_full_text": False,
128
  "temperature": 0.0,
129
  "do_sample": False,
130
+ "generation_kwargs": {"stopping_criteria": EosListStoppingCriteria()}
131
  }
132
 
133
  output = pipe(messages, **generation_args)