Skip to content

Commit

Permalink
drop stop words (InternLM#1823)
Browse files (browse the repository at this point in the history)
* drop stop words

* fix length

* ignore eos only

* fix
  • Loading branch information
grimoire authored Jul 1, 2024
1 parent 6dc8453 commit 78d88d5
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions lmdeploy/pytorch/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
SeqList = List[SchedulerSequence]
AdapterList = List[SchedulerAdapter]

_EMPTY_TOKEN = np.empty((0, ), dtype=np.int64)


def _raise_exception_on_finish(task: asyncio.Task) -> None:
"""raise exception on finish."""
Expand Down Expand Up @@ -506,13 +508,15 @@ def update_running(self, running: SeqList, next_token_ids: torch.Tensor,
stopped: torch.Tensor):
"""update scheduler."""
next_token_ids = next_token_ids.numpy()
eos_token_id = self.model_config.eos_token_id
for token, msg, stop in zip(next_token_ids, running, stopped):
if msg.status != MessageStatus.RUNNING:
continue
msg.num_new_tokens += 1
update_token = token
if stop:
update_token = np.empty((0, ), dtype=np.int64)
if stop or token in eos_token_id:
update_token = _EMPTY_TOKEN
else:
msg.num_new_tokens += 1
msg.update_token_ids(update_token)
if stop:
msg.status = MessageStatus.STOPPED
Expand Down

0 comments on commit 78d88d5

Please sign in to comment.