diff --git a/genparse/lm.py b/genparse/lm.py index c3747a1b..a0bb84ad 100644 --- a/genparse/lm.py +++ b/genparse/lm.py @@ -268,7 +268,7 @@ async def p_next(self, xs, top=None): _logp = await self._model.next_token_logprobs(tokens) _logp = _logp.cpu().numpy() if hasattr(_logp, 'cpu') else _logp - _p = np.exp(_logp.cpu().numpy()) + _p = np.exp(_logp) assert top is None return LazyProb(_p, self._encode, self._decode)