[pre-commit.ci] pre-commit autoupdate (#89)
* [pre-commit.ci] pre-commit autoupdate

updates:
- [github.com/astral-sh/ruff-pre-commit: v0.6.9 → v0.9.3](astral-sh/ruff-pre-commit@v0.6.9...v0.9.3)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pre-commit-ci[bot] authored Feb 1, 2025
1 parent 0dd3515 commit ccda70f
Showing 25 changed files with 94 additions and 94 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,6 +1,6 @@
 repos:
 - repo: https://github.com/astral-sh/ruff-pre-commit
-  rev: v0.6.9
+  rev: v0.9.3
   hooks:
   - id: ruff
     types_or: [ "python", "pyi", "jupyter" ]

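Note: ruff v0.9 (released January 2025) stabilized a new formatter style, and the remaining files in this commit are that formatter's mechanical rewrites (the "auto fixes from pre-commit.com hooks" in the commit message). A short illustrative sketch of the three patterns that account for the changes below; the names n and d are hypothetical, not from this repository:

# 1. Spaces around binary operators inside f-string replacement fields,
#    e.g. f'{n *100:.2f}' becomes f'{n * 100:.2f}'.
# 2. Quote normalization inside f-strings: nested lookups flip to double
#    quotes so the outer quote can follow the project's single-quote style,
#    e.g. f"total: {d['key']}" becomes f'total: {d["key"]}'.
# 3. Multi-line asserts parenthesize the message instead of the condition.
n, d = 0.5, {'key': 'value'}
print(f'{n * 100:.2f}')  # pattern 1
print(f'total: {d["key"]}')  # pattern 2
assert n > 0, (  # pattern 3
    'n must be positive'
)
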
20 changes: 10 additions & 10 deletions amp/amp/ambiguous_parsing/eval/eval.py
@@ -29,29 +29,29 @@ def format_report(scores_by_type):
         other_in_top_k = score_dict['other_in_top_k']

         print(
-            f'\tpred_top_1_matches_correct: {pred_top_1_matches_correct} / {total} = {safe_divide(pred_top_1_matches_correct, total) *100:.2f}'
+            f'\tpred_top_1_matches_correct: {pred_top_1_matches_correct} / {total} = {safe_divide(pred_top_1_matches_correct, total) * 100:.2f}'
         )
         print(
-            f'\tpred_top_1_matches_other: {pred_top_1_matches_other} / {total} = {safe_divide(pred_top_1_matches_other, total) *100:.2f}'
+            f'\tpred_top_1_matches_other: {pred_top_1_matches_other} / {total} = {safe_divide(pred_top_1_matches_other, total) * 100:.2f}'
         )
         print(
-            f'\tpred_top_1_matches_0: {pred_top_1_matches_0} / {total} = {safe_divide(pred_top_1_matches_0, total) *100:.2f}'
+            f'\tpred_top_1_matches_0: {pred_top_1_matches_0} / {total} = {safe_divide(pred_top_1_matches_0, total) * 100:.2f}'
         )
         print(
-            f'\tpred_top_1_matches_1: {pred_top_1_matches_1} / {total} = {safe_divide(pred_top_1_matches_1, total) *100:.2f}'
+            f'\tpred_top_1_matches_1: {pred_top_1_matches_1} / {total} = {safe_divide(pred_top_1_matches_1, total) * 100:.2f}'
         )

         print(
-            f'\tpred_top_2_matches_other: {pred_top_2_matches_other} / {total} = {safe_divide(pred_top_2_matches_other, total) *100:.2f}'
+            f'\tpred_top_2_matches_other: {pred_top_2_matches_other} / {total} = {safe_divide(pred_top_2_matches_other, total) * 100:.2f}'
         )
         print(
-            f'\tpred_top_2_matches_correct: {pred_top_2_matches_correct} / {total} = {safe_divide(pred_top_2_matches_correct, total) *100:.2f}'
+            f'\tpred_top_2_matches_correct: {pred_top_2_matches_correct} / {total} = {safe_divide(pred_top_2_matches_correct, total) * 100:.2f}'
         )
         print(
-            f'\tcorrect_in_top_k: {correct_in_top_k} / {total} = {safe_divide(correct_in_top_k, total) *100:.2f}'
+            f'\tcorrect_in_top_k: {correct_in_top_k} / {total} = {safe_divide(correct_in_top_k, total) * 100:.2f}'
         )
         print(
-            f'\tother_in_top_k: {other_in_top_k} / {total} = {safe_divide(other_in_top_k, total) *100:.2f}'
+            f'\tother_in_top_k: {other_in_top_k} / {total} = {safe_divide(other_in_top_k, total) * 100:.2f}'
         )
         print('=====================================')

@@ -114,10 +114,10 @@ def get_score_data(test_data, pred_data, test_data_lut, is_fol=False, convert=Tr
             scores_by_type[ex_type]['total'] += 1

     print(
-        f'{missing_first} = {missing_first / len(pred_data) * 100 :.2f} are missing a first output'
+        f'{missing_first} = {missing_first / len(pred_data) * 100:.2f} are missing a first output'
     )
     print(
-        f'{missing_second} = {missing_second / len(pred_data) * 100 :.2f} are missing a second output'
+        f'{missing_second} = {missing_second / len(pred_data) * 100:.2f} are missing a second output'
     )

     # divide everything by the total

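safe_divide is called throughout format_report above but is defined outside this diff. A plausible sketch of its behavior, assuming it exists to guard against empty score buckets (hypothetical, not the repository's actual code):

def safe_divide(numerator, denominator):
    # Assumed behavior: return 0.0 rather than raising ZeroDivisionError
    # when a score bucket is empty.
    return numerator / denominator if denominator else 0.0
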
10 changes: 5 additions & 5 deletions amp/amp/ambiguous_parsing/tree/formula.py
@@ -191,7 +191,7 @@ def order_helper(node, parent_op=None):
         # print(f"children of {statements.nodes[node]['name']} are {child_names}")
         if len(children) == 0:
             return split_name(statements.nodes[node]['name'], atom=True)
-        op = f" {split_name(statements.nodes[node]['name'])} "
+        op = f' {split_name(statements.nodes[node]["name"])} '

         # don't use parens if the same type of parent and parent is AND or OR
         # i.e. instead of ((a AND b) AND c) allow (a AND b AND c)
@@ -224,7 +224,7 @@ def parse_formula(cls, formula: str) -> 'Formula':
             fxn_and_args = atom.split(' ')
             fxn = fxn_and_args[0]
             args = fxn_and_args[1:]
-            new_str = f"{fxn}[{','.join(args)}]"
+            new_str = f'{fxn}[{",".join(args)}]'
             formula = re.sub(f'\( {atom} \)', new_str, formula)

         # manipulate quantifiers to make parsing easy
@@ -259,7 +259,7 @@ def render(self, ordered_vars: bool = False) -> str:
                 new_root = f'quant:{i}'
                 statements.add_node(f'quant:{i}', name=q)
                 if i < len(quantifiers) - 1:
-                    statements.add_edge(f'quant:{i}', f'quant:{i+1}')
+                    statements.add_edge(f'quant:{i}', f'quant:{i + 1}')
                 else:
                     statements.add_edge(f'quant:{i}', root)

@@ -295,9 +295,9 @@ def order_helper(node):
             if len(children) == 0:
                 fxn_str = split_fxn_name(statements.nodes[node]['name'])
                 return fxn_str
-            op = f"{split_name(statements.nodes[node]['name'])}"
+            op = f'{split_name(statements.nodes[node]["name"])}'

-            return f"( {op} {' '.join([order_helper(n) for n in children])} )"
+            return f'( {op} {" ".join([order_helper(n) for n in children])} )'

         seq = order_helper(root)
         return seq

4 changes: 2 additions & 2 deletions amp/amp/ambiguous_parsing/tree/tree_tools.py
@@ -181,15 +181,15 @@ def convert_to_tree(output, do_flatten=False):
             new_suffix = sorted(new_suffix)
             # pdb.set_trace()
             graph.nodes[n]['name'] = (
-                f"{graph.nodes[n]['name'].split(':')[0]}:{''.join(new_suffix)}"
+                f'{graph.nodes[n]["name"].split(":")[0]}:{"".join(new_suffix)}'
             )
             new_child_names.append(re.sub(':', '', graph.nodes[n]['name']))
             # new_child_names.append(graph.nodes[n]['name'].split(":")[0])
         # remove child variable names
         new_child_names = [re.sub(r'\[.*?\]', '', x) for x in new_child_names]
         # pdb.set_trace()
         graph.nodes[parent]['name'] = (
-            f"{graph.nodes[parent]['name'].split(':')[0]}:{''.join(new_child_names)}"
+            f'{graph.nodes[parent]["name"].split(":")[0]}:{"".join(new_child_names)}'
         )

     return graph

2 changes: 1 addition & 1 deletion bench/run_calflow_baseline.py
@@ -111,7 +111,7 @@ def main():
         prompt = prompt_builder.assemble(selected_train_data, dev_datum)
         samples.append({'datum': dev_datum, 'prompt': prompt})

-    logger.info(f"Example prompt:\n{samples[0]['prompt']}")
+    logger.info(f'Example prompt:\n{samples[0]["prompt"]}')

     llm = vllm.LLM(model=args.model_name)
     llm_outputs = llm.generate([s['prompt'] for s in samples], sampling_params)

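The vLLM calls in main() above follow the library's standard offline-inference pattern. A minimal self-contained sketch; the model name and sampling settings here are placeholders rather than the script's actual arguments:

import vllm

sampling_params = vllm.SamplingParams(temperature=0.0, max_tokens=64)
llm = vllm.LLM(model='facebook/opt-125m')  # placeholder model name
outputs = llm.generate(['Example prompt'], sampling_params)
print(outputs[0].outputs[0].text)  # first completion of the first prompt
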
2 changes: 1 addition & 1 deletion bench/run_calflow_genparse.py
@@ -112,7 +112,7 @@ def main():
         prompt = prompt_builder.assemble(selected_train_data, dev_datum)
         samples.append({'datum': dev_datum, 'prompt': prompt})

-    logger.debug(f"Example prompt:\n{samples[0]['prompt']}")
+    logger.debug(f'Example prompt:\n{samples[0]["prompt"]}')

     batch_llm = BatchVLLM.from_name(args.model_name)

2 changes: 1 addition & 1 deletion bench/run_spider_genparse.py
@@ -189,7 +189,7 @@ def main():
     n_query = args.n_query

     outpath = (
-        f'{args.exp_name}-{args.inference}-' f'p{args.particles}-b{args.n_beam}-{n_query}'
+        f'{args.exp_name}-{args.inference}-p{args.particles}-b{args.n_beam}-{n_query}'
     )
     if args.schema_grammar:
         outpath += '-schema'

6 changes: 3 additions & 3 deletions bench/run_spider_poe.py
@@ -228,9 +228,9 @@ def main():
             n_mismatch += 1

         print(
-            f'correct: {n_correct / (i+1):.2f}, '
-            f'invalid: {n_invalid / (i+1):.2f}, '
-            f'mismatch: {n_mismatch / (i+1):.2f}'
+            f'correct: {n_correct / (i + 1):.2f}, '
+            f'invalid: {n_invalid / (i + 1):.2f}, '
+            f'mismatch: {n_mismatch / (i + 1):.2f}'
         )

12 changes: 6 additions & 6 deletions bench/spider/process_sql.py
@@ -202,9 +202,9 @@ def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
             key = tables_with_alias[alias] + '.' + col
             return start_idx + 1, schema.idMap[key]

-    assert (
-        default_tables is not None and len(default_tables) > 0
-    ), 'Default tables should not be None or empty'
+    assert default_tables is not None and len(default_tables) > 0, (
+        'Default tables should not be None or empty'
+    )

     for alias in default_tables:
         table = tables_with_alias[alias]
@@ -352,9 +352,9 @@ def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=N
         not_op = True
         idx += 1

-    assert (
-        idx < len_ and toks[idx] in WHERE_OPS
-    ), 'Error condition: idx: {}, tok: {}'.format(idx, toks[idx])
+    assert idx < len_ and toks[idx] in WHERE_OPS, (
+        'Error condition: idx: {}, tok: {}'.format(idx, toks[idx])
+    )
     op_id = WHERE_OPS.index(toks[idx])
     idx += 1
     val1 = val2 = None

30 changes: 15 additions & 15 deletions genparse/batch_inference/lm.py
@@ -191,9 +191,9 @@ def batch_next_token_logprobs(self, particles, is_initial=False):
             self.llm_engine.scheduler.schedule()
         )

-        assert (
-            len(seq_group_metadata_list) == 1
-        ), 'There should only be a single sequence group'
+        assert len(seq_group_metadata_list) == 1, (
+            'There should only be a single sequence group'
+        )

         # update particle metadata with scheduler metadata and output
         self.particle_metadata.scheduler_outputs = scheduler_outputs
@@ -227,18 +227,18 @@ def batch_next_token_logprobs(self, particles, is_initial=False):

         self.particle_metadata.sequence_ids_by_seq_group = sequence_ids_by_seq_group

-        assert (
-            len(logprobs_by_seq_group) == 1
-        ), 'There should only be one sequence group (logprobs)'
-        assert (
-            len(logprobs_by_seq_group[0]) == 1
-        ), 'We should only be decoding a single step (logprobs)'
-        assert (
-            len(sequence_ids_by_seq_group) == 1
-        ), 'There should only be one sequence group (sequence ids)'
-        assert (
-            len(sequence_ids_by_seq_group[0]) == 1
-        ), 'We should only be decoding a single step (sequence ids)'
+        assert len(logprobs_by_seq_group) == 1, (
+            'There should only be one sequence group (logprobs)'
+        )
+        assert len(logprobs_by_seq_group[0]) == 1, (
+            'We should only be decoding a single step (logprobs)'
+        )
+        assert len(sequence_ids_by_seq_group) == 1, (
+            'There should only be one sequence group (sequence ids)'
+        )
+        assert len(sequence_ids_by_seq_group[0]) == 1, (
+            'We should only be decoding a single step (sequence ids)'
+        )

         logprobs = logprobs_by_seq_group[0][0]
         sequence_ids = sequence_ids_by_seq_group[0][0]

8 changes: 4 additions & 4 deletions genparse/batch_inference/steer.py
@@ -117,9 +117,9 @@ def __init__(

     def twist(self, log_potential):
         if self.log_potential == -np.inf:
-            assert (
-                log_potential == -np.inf
-            ), 'Potentials φ must satisfy φ(x) = 0 => φ(xy) = 0, forall x,y in V*'
+            assert log_potential == -np.inf, (
+                'Potentials φ must satisfy φ(x) = 0 => φ(xy) = 0, forall x,y in V*'
+            )
             self.log_weight = -np.inf
         else:
             self.log_weight += log_potential - self.log_potential
@@ -232,7 +232,7 @@ def pretty_print_particles(particles, step_info):
     for i, p in enumerate(particles):
         print(f'├ Particle {i:3d} `{p.context[-1]}` : {p}')
     print(
-        f"│ Step {step_info['step']:3d} average weight: {step_info['average_weight']:.4f}"
+        f'│ Step {step_info["step"]:3d} average weight: {step_info["average_weight"]:.4f}'
     )

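The invariant asserted in twist above, that a zero potential for a prefix forces a zero potential for every extension, is what keeps the incremental weight update sound: a particle's log-weight accumulates only the change in log-potential, so a dead prefix must pin the weight at -inf. A standalone sketch of that update on plain floats (names hypothetical):

import numpy as np

log_weight, log_potential = 0.0, -1.0  # running particle state
new_log_potential = -1.5  # log-potential of the extended prefix
if log_potential == -np.inf:
    log_weight = -np.inf  # once dead, always dead
else:
    log_weight += new_log_potential - log_potential
log_potential = new_log_potential
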
@@ -152,7 +152,7 @@ def main():
         '--output',
         type=pathlib.Path,
         required=True,
-        help='Path to an output .pt file where the canonicalizer will be ' 'written.',
+        help='Path to an output .pt file where the canonicalizer will be written.',
     )
     args = parser.parse_args()

12 changes: 6 additions & 6 deletions genparse/cfg.py
@@ -223,9 +223,9 @@ def num_rules(self):
     @property
     def expected_length(self):
         """Computes the expected length of a string using the Expecattion semiring ()"""
-        assert (
-            self.R == Float
-        ), 'This method only supports grammars over the Float semiring'
+        assert self.R == Float, (
+            'This method only supports grammars over the Float semiring'
+        )
         new_cfg = self.__class__(R=Expectation, S=self.S, V=self.V)
         for r in self:
             new_cfg.add(
@@ -294,9 +294,9 @@ def assert_equal(self, other, verbose=False, throw=True):
                     colors.mark(r in G),
                     r,
                 )
-        assert not throw or Counter(self.rules) == Counter(
-            other.rules
-        ), f'\n\nhave=\n{str(self)}\nwant=\n{str(other)}'
+        assert not throw or Counter(self.rules) == Counter(other.rules), (
+            f'\n\nhave=\n{str(self)}\nwant=\n{str(other)}'
+        )

     def treesum(self, **kwargs):
         "Total weight of the start symbol."

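The expected_length property above uses the expectation semiring: each rule's probability is paired with a probability-weighted count, and summing over the language yields the pair (Z, Z * E[length]) in a single pass. A self-contained sketch of that algebra; this Expectation class is a stand-in, not genparse's actual implementation:

class Expectation:
    # Pair (p, r): probability mass and p-weighted accumulated length.
    def __init__(self, p, r):
        self.p, self.r = p, r

    def __add__(self, other):  # summing alternatives
        return Expectation(self.p + other.p, self.r + other.r)

    def __mul__(self, other):  # chaining: (p1*p2, p1*r2 + p2*r1)
        return Expectation(self.p * other.p, self.p * other.r + other.p * self.r)

# Toy language: 'ab' with probability 0.3, 'abc' with probability 0.7.
total = Expectation(0.3, 0.3 * 2) + Expectation(0.7, 0.7 * 3)
print(total.r / total.p)  # expected length: 2.7 (up to float rounding)
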
4 changes: 2 additions & 2 deletions genparse/experimental/gad.py
@@ -164,7 +164,7 @@ def update_backward(self, lm1, lm2):

     def graphviz(
         self,
-        fmt_edge=lambda x, a, y: f'{html.escape(str(a))}/{y.mass/x.mass:.2g}',
+        fmt_edge=lambda x, a, y: f'{html.escape(str(a))}/{y.mass / x.mass:.2g}',
         # fmt_node=lambda x: ' ',
         fmt_node=lambda x: (
             # f'{x.mass}/{x._mass:.2g}' if x.mass > 0 else f'{x._mass:.2g}'
@@ -192,7 +192,7 @@ def graphviz(
             if x.children is None:
                 continue
             for a, y in x.children.items():
-                g.edge(str(f(x)), str(f(y)), label=f'{fmt_edge(x,a,y)}')
+                g.edge(str(f(x)), str(f(y)), label=f'{fmt_edge(x, a, y)}')
                 q.append(y)
         for x in xs:
             if x.children is not None:

12 changes: 6 additions & 6 deletions genparse/lm.py
@@ -213,9 +213,9 @@ def encode_prompt(self, prompt):
     def __call__(self, context):
         assert isinstance(context, tuple), '`context` must be explicitly tokenized'
         assert set(context) <= self.V, f'OOVs detected: {set(context) - self.V}'
-        assert (
-            context[-1] == self.eos
-        ), f'Context must end with eos ({self.eos!r}); got {context = }.'
+        assert context[-1] == self.eos, (
+            f'Context must end with eos ({self.eos!r}); got {context = }.'
+        )
         if self.temperature == 1 and self.top_p is None:
             return self._model([self._encode[x] for x in context])
         else:
@@ -239,9 +239,9 @@ async def p_next_async(self, context, _logp=None, return_logp=False):
         # _logp is provided by the vllm centralized step function

         if _logp is None:
-            assert isinstance(
-                context, tuple
-            ), 'API change; `context` must be explicitly tokenized'
+            assert isinstance(context, tuple), (
+                'API change; `context` must be explicitly tokenized'
+            )
             assert set(context) <= self.V, f'OOVs detected: {set(context) - self.V}'

             tokens = [self._encode[x] for x in context]

10 changes: 5 additions & 5 deletions genparse/record.py
@@ -150,11 +150,11 @@ def plotly(
             hoverinfo='text',
             hovertext=resampled_as_data.apply(
                 lambda row: (
-                    f"Token: {'`<b>'+row['token']+'</b>`' if row['token'] else ''}<br>"
-                    + f"Context: {row['context_string']}<br>"
-                    + f"Step {row['step']}; Avg weight = {row['average weight']:4f}<br>"
-                    + f"Particle {row['particle']}; Weight = {row['weight']:4f}<br>"
-                    + f"{' ↳ resample_indices particle '+str(row['resample_indices']) if row['resample?'] else ''}"
+                    f'Token: {"`<b>" + row["token"] + "</b>`" if row["token"] else ""}<br>'
+                    + f'Context: {row["context_string"]}<br>'
+                    + f'Step {row["step"]}; Avg weight = {row["average weight"]:4f}<br>'
+                    + f'Particle {row["particle"]}; Weight = {row["weight"]:4f}<br>'
+                    + f'{" ↳ resample_indices particle " + str(row["resample_indices"]) if row["resample?"] else ""}'
                 ),
                 axis=1,
             ),

6 changes: 3 additions & 3 deletions genparse/tokenization.py
@@ -40,9 +40,9 @@ def decode_tokenizer_vocab(tokenizer):
     for i, t in enumerate(decoded):
         tmp[t].append(i)
     for x in tmp:
-        assert (
-            len(tmp[x]) == 1
-        ), f'surface form {x!r} maps to more than one token> {tmp[x]}'
+        assert len(tmp[x]) == 1, (
+            f'surface form {x!r} maps to more than one token> {tmp[x]}'
+        )

     return decoded

4 changes: 2 additions & 2 deletions genparse/trace.py
@@ -143,7 +143,7 @@ def update(self):

     def graphviz(
         self,
-        fmt_edge=lambda x, a, y: f'{html.escape(str(a))}/{y._mass/x._mass:.2g}',
+        fmt_edge=lambda x, a, y: f'{html.escape(str(a))}/{y._mass / x._mass:.2g}',
         # fmt_node=lambda x: ' ',
         fmt_node=lambda x: (
             f'{x.mass}/{x._mass:.2g}' if x.mass > 0 else f'{x._mass:.2g}'
@@ -171,7 +171,7 @@ def graphviz(
                 continue
             for a, y in x.active_children.items():
                 a = y.token if y.token is not None else a
-                g.edge(str(f(x)), str(f(y)), label=f'{fmt_edge(x,a,y)}')
+                g.edge(str(f(x)), str(f(y)), label=f'{fmt_edge(x, a, y)}')
                 q.append(y)
         for x in xs:
             if x.child_masses is not None:

2 changes: 1 addition & 1 deletion genparse/trace1.py
@@ -150,7 +150,7 @@ def graphviz(
             if x.children is None:
                 continue
             for a, y in x.children.items():
-                g.edge(str(f(x)), str(f(y)), label=f'{fmt_edge(x,a,y)}')
+                g.edge(str(f(x)), str(f(y)), label=f'{fmt_edge(x, a, y)}')
                 q.append(y)
         for x in xs:
             if x.children is not None: