Skip to content

Commit

Permalink
Refactored from_json methods of Groth16Proof and Groth16VerifyingKe…
Browse files Browse the repository at this point in the history
…ys classes (keep-starknet-strange#199)

Co-authored-by: casiojapi <[email protected]>
  • Loading branch information
fatalbatros and casiojapi authored Sep 16, 2024
1 parent 66abfdc commit 6135bd6
Showing 1 changed file with 64 additions and 51 deletions.
115 changes: 64 additions & 51 deletions hydra/garaga/starknet/groth16_contract_generator/parsing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,8 @@ def __post_init__(self):
def curve_id(self) -> CurveID:
return self.alpha.curve_id

def from_json(file_path: str | Path) -> "Groth16VerifyingKey":
path = Path(file_path)
def from_dict(data: dict) -> "Groth16VerifyingKey":
try:
with path.open("r") as f:
data = json.load(f)
curve_id = try_guessing_curve_id_from_json(data)
try:
verifying_key = find_item_from_key_patterns(data, ["verifying_key"])
Expand Down Expand Up @@ -237,15 +234,22 @@ def from_json(file_path: str | Path) -> "Groth16VerifyingKey":
for point in find_item_from_key_patterns(g1_points, ["K"])
],
)
except KeyError as e:
raise KeyError(f"The key {e} is missing from the JSON data.")

def from_json(file_path: str | Path) -> "Groth16VerifyingKey":
path = Path(file_path)
try:
with path.open("r") as f:
data = json.load(f)
except FileNotFoundError:
cwd = os.getcwd()
print(f"Current working directory: {cwd}")
print(f"Attempted to access file: {os.path.abspath(file_path)}")
raise FileNotFoundError(f"The file {file_path} was not found.")
except json.JSONDecodeError:
raise ValueError(f"The file {file_path} does not contain valid JSON.")
except KeyError as e:
raise KeyError(f"The key {e} is missing from the JSON data.")
return Groth16VerifyingKey.from_dict(data)

def serialize_to_cairo(self) -> str:
# Precompute M = miller_loop(public_pair)
Expand Down Expand Up @@ -308,64 +312,73 @@ def __post_init__(self):
), f"All points must be on the same curve, got {self.a.curve_id}, {self.b.curve_id}, {self.c.curve_id}"
self.curve_id = self.a.curve_id

def from_dict(
data: dict, public_inputs: None | list | dict = None
) -> "Groth16Proof":
curve_id = try_guessing_curve_id_from_json(data)
try:
proof = find_item_from_key_patterns(data, ["proof"])
except ValueError:
proof = data

try:
seal = io.to_hex_str(find_item_from_key_patterns(data, ["seal"]))
image_id = io.to_hex_str(find_item_from_key_patterns(data, ["image_id"]))
journal = io.to_hex_str(find_item_from_key_patterns(data, ["journal"]))

return Groth16Proof._from_risc0(
seal=bytes.fromhex(seal[2:]),
image_id=bytes.fromhex(image_id[2:]),
journal=bytes.fromhex(journal[2:]),
)
except ValueError:
pass
except KeyError:
pass
except Exception as e:
print(f"Error: {e}")
raise e

if public_inputs is not None:
if isinstance(public_inputs, dict):
public_inputs = list(public_inputs.values())
elif isinstance(public_inputs, list):
pass
else:
raise ValueError(f"Invalid public inputs format: {public_inputs}")
else:
public_inputs = find_item_from_key_patterns(data, ["public"])
return Groth16Proof(
a=try_parse_g1_point_from_key(proof, ["a"], curve_id),
b=try_parse_g2_point_from_key(proof, ["b"], curve_id),
c=try_parse_g1_point_from_key(proof, ["c", "Krs"], curve_id),
public_inputs=[io.to_int(pub) for pub in public_inputs],
)

def from_json(
proof_path: str | Path, public_inputs_path: str | Path = None
) -> "Groth16Proof":
path = Path(proof_path)
try:
with path.open("r") as f:
data = json.load(f)
# print(f"data: {data}")
# print(f"data.keys(): {data.keys()}")
curve_id = try_guessing_curve_id_from_json(data)

try:
proof = find_item_from_key_patterns(data, ["proof"])
except ValueError:
proof = data

try:
seal = io.to_hex_str(find_item_from_key_patterns(data, ["seal"]))
image_id = io.to_hex_str(
find_item_from_key_patterns(data, ["image_id"])
)
journal = io.to_hex_str(find_item_from_key_patterns(data, ["journal"]))

return Groth16Proof._from_risc0(
seal=bytes.fromhex(seal[2:]),
image_id=bytes.fromhex(image_id[2:]),
journal=bytes.fromhex(journal[2:]),
)
except ValueError:
pass
except KeyError:
pass
except Exception as e:
print(f"Error: {e}")
raise e

except FileNotFoundError:
raise FileNotFoundError(f"The file {proof_path} was not found.")
except json.JSONDecodeError:
raise ValueError(f"The file {proof_path} does not contain valid JSON.")
try:
if public_inputs_path is not None:
with Path(public_inputs_path).open("r") as f:
public_inputs = json.load(f)
print(f"public_inputs: {public_inputs}")
if isinstance(public_inputs, dict):
public_inputs = list(public_inputs.values())
elif isinstance(public_inputs, list):
pass
else:
raise ValueError(f"Invalid public inputs format: {public_inputs}")
else:
public_inputs = find_item_from_key_patterns(data, ["public"])
return Groth16Proof(
a=try_parse_g1_point_from_key(proof, ["a"], curve_id),
b=try_parse_g2_point_from_key(proof, ["b"], curve_id),
c=try_parse_g1_point_from_key(proof, ["c", "Krs"], curve_id),
public_inputs=[io.to_int(pub) for pub in public_inputs],
)
public_inputs = None
except FileNotFoundError:
raise FileNotFoundError(f"The file {proof_path} was not found.")
raise FileNotFoundError(f"The file {public_inputs_path} was not found.")
except json.JSONDecodeError:
raise ValueError(f"The file {proof_path} does not contain valid JSON.")
raise ValueError(
f"The file {public_inputs_path} does not contain valid JSON."
)
return Groth16Proof.from_dict(data, public_inputs)

def _from_risc0(
seal: bytes,
Expand Down

0 comments on commit 6135bd6

Please sign in to comment.