-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
42 lines (30 loc) · 1.52 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from urllib.parse import urlparse
from typing import Tuple, Any, Optional
import torch
def accumulate_padding(input_embeds: torch.Tensor, attention_mask: torch.Tensor, padding_side: str = 'right') -> Tuple[torch.Tensor, torch.Tensor]:
    """Gather the padding positions of every sequence onto one side.

    Within each row the positions are reordered so that all padding
    positions (mask value ``0``) end up contiguous on ``padding_side``
    while the relative order of the real tokens, and of the padding
    positions among themselves, is preserved.

    Args:
        input_embeds: Tensor whose first two dimensions are
            ``(batch, seq_len)``; any trailing dimensions are carried along.
        attention_mask: 2-D tensor of shape ``(batch, seq_len)``. A value of
            ``0`` marks padding; any non-zero value marks a real token.
        padding_side: ``'right'`` to move padding to the end of each row,
            ``'left'`` to move it to the start.

    Returns:
        A ``(new_input_embeds, new_attention_mask)`` tuple of freshly
        allocated tensors with the same shapes and dtypes as the inputs.

    Raises:
        ValueError: If ``padding_side`` is not ``'right'`` or ``'left'``, or
            if the leading dimensions of the two tensors do not match.
    """
    # Explicit raise instead of `assert`: asserts vanish under `python -O`,
    # which would silently turn an invalid side into 'right'.
    if padding_side not in ('right', 'left'):
        raise ValueError(
            f"padding_side must be 'right' or 'left', got {padding_side!r}"
        )
    if attention_mask.dim() != 2:
        raise ValueError(
            "attention_mask must be 2-D (batch, seq_len), got shape "
            f"{tuple(attention_mask.shape)}"
        )
    if tuple(input_embeds.shape[:2]) != tuple(attention_mask.shape):
        raise ValueError(
            "input_embeds and attention_mask must agree on (batch, seq_len): "
            f"{tuple(input_embeds.shape[:2])} vs {tuple(attention_mask.shape)}"
        )

    is_padding = attention_mask == 0
    # Sort key is 0 for the positions that must come first and 1 for the rest.
    # A *stable* ascending sort then keeps the original order inside each group.
    goes_last = is_padding if padding_side == 'right' else ~is_padding
    order = torch.sort(goes_last.to(torch.uint8), dim=1, stable=True).indices

    # Column vector of row numbers so that [rows, order] picks, for every row,
    # that row's positions in the new order (trailing dims are carried along).
    rows = torch.arange(order.shape[0], device=order.device).unsqueeze(1)
    return input_embeds[rows, order], attention_mask[rows, order]
class torch_dtype:
    """Context manager that temporarily changes torch's default dtype.

    On entry the current default dtype is remembered and, unless ``dtype``
    is ``None``, replaced by ``dtype``. On exit the remembered default is
    put back, whether or not the body raised. With ``dtype=None`` the
    manager is a no-op.
    """

    def __init__(self, dtype: Optional[torch.dtype]) -> None:
        """Store the dtype to install; ``None`` leaves the default untouched."""
        self.dtype = dtype

    def __enter__(self) -> Any:
        """Save the current default dtype, install ``self.dtype``, return self."""
        self.dtype_orig = torch.get_default_dtype()
        if self.dtype is not None:
            torch.set_default_dtype(self.dtype)
        # Returning the manager makes `with torch_dtype(...) as ctx:` usable.
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Optional[bool]:
        """Restore the saved default dtype; never suppresses exceptions."""
        if self.dtype is not None:
            torch.set_default_dtype(self.dtype_orig)
        # A falsy return lets any exception from the body propagate.
        return None
def is_remote_url(url_or_filename):
    """Report whether the argument parses as an ``http`` or ``https`` URL.

    Only the scheme component produced by ``urlparse`` is inspected; anything
    with another scheme, or with none at all (such as a plain file path),
    yields ``False``.
    """
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"