Skip to content

Commit

Permalink
[chore] add better type hints and potential mismatches found with myp…
Browse files Browse the repository at this point in the history
…y. I will not add mypy in pre-commit though because it's too many false positives
  • Loading branch information
mangiucugna committed May 26, 2024
1 parent 81580dc commit 5970b37
Showing 1 changed file with 73 additions and 60 deletions.
133 changes: 73 additions & 60 deletions src/json_repair/json_repair.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,37 +24,79 @@

import os
import json
from typing import Any, Dict, List, Union, TextIO
from typing import Any, Dict, List, Literal, Optional, Union, TextIO, Tuple, TypeAlias


class StringFileWrapper:
# This is a trick to simplify the code, transform the filedescriptor handling into a string handling
def __init__(self, fd: TextIO) -> None:
self.fd = fd
self.length: int = 0

def __getitem__(self, index: int) -> str:
if isinstance(index, slice):
self.fd.seek(index.start)
value = self.fd.read(index.stop - index.start)
self.fd.seek(index.start)
return value
else:
self.fd.seek(index)
return self.fd.read(1)

def __len__(self) -> int:
if self.length < 1:
current_position = self.fd.tell()
self.fd.seek(0, os.SEEK_END)
self.length = self.fd.tell()
self.fd.seek(current_position)
return self.length

def __setitem__(self) -> None:
raise Exception("This is read-only!")


class LoggerConfig:
# This is a type class to simplify the declaration
def __init__(self, log_level: Optional[str]):
self.log: List[Dict[str, str]] = []
self.window: int = 10
self.log_level: str = log_level if log_level else "none"


JSONReturnType: TypeAlias = Union[
Dict[str, Any], List[Any], str, float, int, bool, None
]


class JSONParser:
def __init__(self, json_str: str, json_fd: TextIO, logging: bool = False) -> None:
def __init__(
self,
json_str: str | StringFileWrapper,
json_fd: Optional[TextIO],
logging: Optional[bool],
) -> None:
# The string to parse
self.json_str = json_str
# Alternatively, the file description with a json file in it
if json_fd:
# This is a trick we do to treat the file wrapper as an array
self.json_str = StringFileWrapper(json_fd)
# Index is our iterator that will keep track of which character we are looking at right now
self.index = 0
self.index: int = 0
# This is used in the object member parsing to manage the special cases of missing quotes in key or value
self.context = []
self.context: list[str] = []
# Use this to log the activity, but only if logging is active
self.logger = {
"log": [],
"window": 10,
"log_level": "info" if logging else "none",
}

def parse(self) -> Union[Dict[str, Any], List[Any], str, float, int, bool, None]:
if self.logger["log_level"] == "none":
self.logger = LoggerConfig(log_level="info" if logging else None)

def parse(self) -> JSONReturnType | Tuple[JSONReturnType, List[Dict[str, str]]]:
if self.logger.log_level == "none":
return self.parse_json()
else:
return self.parse_json(), self.logger["log"]
return self.parse_json(), self.logger.log

def parse_json(
self,
) -> Union[Dict[str, Any], List[Any], str, float, int, bool, None]:
) -> JSONReturnType:
char = self.get_char_at()
# False means that we are at the end of the string provided, is the base case for recursion
if char is False:
Expand Down Expand Up @@ -225,7 +267,7 @@ def parse_array(self) -> List[Any]:
self.reset_context()
return arr

def parse_string(self) -> str:
def parse_string(self) -> str | JSONReturnType:
# <string> is a string of valid characters enclosed in quotes
# i.e. { name: "John" }
# Somehow all weird cases in an invalid JSON happen to be resolved in this function, so be careful here
Expand Down Expand Up @@ -324,7 +366,7 @@ def parse_string(self) -> str:
string_acc = string_acc[:-1]
if char in [rstring_delimiter, "t", "n", "r", "b", "\\"]:
escape_seqs = {"t": "\t", "n": "\n", "r": "\r", "b": "\b"}
string_acc += escape_seqs.get(char, char)
string_acc += escape_seqs.get(char, char) or char
self.index += 1
char = self.get_char_at()
# ChatGPT sometimes forget to quote stuff in html tags or markdown, so we do this whole thing here
Expand Down Expand Up @@ -418,7 +460,7 @@ def parse_string(self) -> str:

return string_acc.rstrip()

def parse_number(self) -> Union[float, int, str]:
def parse_number(self) -> float | int | str | JSONReturnType:
# <number> is a valid real number expressed in one of a number of given formats
number_str = ""
number_chars = set("0123456789-.eE/,")
Expand Down Expand Up @@ -451,7 +493,6 @@ def parse_number(self) -> Union[float, int, str]:
def parse_boolean_or_null(self) -> Union[bool, str, None]:
# <boolean> is one of the literal strings 'true', 'false', or 'null' (unquoted)
starting_index = self.index
value = ""
char = self.get_char_at().lower()
if char == "t":
value = ("true", True)
Expand All @@ -460,7 +501,7 @@ def parse_boolean_or_null(self) -> Union[bool, str, None]:
elif char == "n":
value = ("null", None)

if len(value):
if value:
i = 0
while char and i < len(value[0]) and char == value[0][i]:
i += 1
Expand All @@ -473,7 +514,7 @@ def parse_boolean_or_null(self) -> Union[bool, str, None]:
self.index = starting_index
return ""

def get_char_at(self, count: int = 0) -> Union[str, bool]:
def get_char_at(self, count: int = 0) -> Union[str, Literal[False]]:
# Why not use something simpler? Because try/except in python is a faster alternative to an "if" statement that is often True
try:
return self.json_str[self.index + count]
Expand Down Expand Up @@ -513,12 +554,12 @@ def get_context(self) -> str:
return ""

def log(self, text: str, level: str) -> None:
if level == self.logger["log_level"]:
if level == self.logger.log_level:
context = ""
start = max(self.index - self.logger["window"], 0)
end = min(self.index + self.logger["window"], len(self.json_str))
start = max(self.index - self.logger.window, 0)
end = min(self.index + self.logger.window, len(self.json_str))
context = self.json_str[start:end]
self.logger["log"].append(
self.logger.log.append(
{
"text": text,
"context": context,
Expand All @@ -528,11 +569,11 @@ def log(self, text: str, level: str) -> None:

def repair_json(
json_str: str = "",
return_objects: bool = False,
skip_json_loads: bool = False,
logging: bool = False,
json_fd: TextIO = None,
) -> Union[Dict[str, Any], List[Any], str, float, int, bool, None]:
return_objects: Optional[bool] = False,
skip_json_loads: Optional[bool] = False,
logging: Optional[bool] = False,
json_fd: Optional[TextIO] = None,
) -> JSONReturnType | Tuple[JSONReturnType, List[Dict[str, str]]]:
"""
Given a json formatted string, it will try to decode it and, if it fails, it will try to fix it.
It will return the fixed string by default.
Expand All @@ -559,7 +600,7 @@ def repair_json(

def loads(
json_str: str, skip_json_loads: bool = False, logging: bool = False
) -> Union[Dict[str, Any], List[Any], str, float, int, bool, None]:
) -> JSONReturnType | Tuple[JSONReturnType, List[Dict[str, str]]]:
"""
This function works like `json.loads()` except that it will fix your JSON in the process.
It is a wrapper around the `repair_json()` function with `return_objects=True`.
Expand All @@ -574,7 +615,7 @@ def loads(

def load(
fd: TextIO, skip_json_loads: bool = False, logging: bool = False
) -> Union[Dict[str, Any], List[Any], str, float, int, bool, None]:
) -> JSONReturnType | Tuple[JSONReturnType, List[Dict[str, str]]]:
"""
This function works like `json.load()` except that it will fix your JSON in the process.
It is a wrapper around the `repair_json()` function with `json_fd=fd` and `return_objects=True`.
Expand All @@ -584,7 +625,7 @@ def load(

def from_file(
filename: str, skip_json_loads: bool = False, logging: bool = False
) -> Union[Dict[str, Any], List[Any], str, float, int, bool, None]:
) -> JSONReturnType | Tuple[JSONReturnType, List[Dict[str, str]]]:
"""
This function is a wrapper around `load()` so you can pass the filename as string
"""
Expand All @@ -593,31 +634,3 @@ def from_file(
fd.close()

return jsonobj


class StringFileWrapper:
# This is a trick to simplify the code above, transform the filedescriptor handling into an array handling
def __init__(self, fd: TextIO) -> None:
self.fd = fd
self.length = None

def __getitem__(self, index: int) -> Any:
if isinstance(index, slice):
self.fd.seek(index.start)
value = self.fd.read(index.stop - index.start)
self.fd.seek(index.start)
return value
else:
self.fd.seek(index)
return self.fd.read(1)

def __len__(self) -> int:
if not self.length:
current_position = self.fd.tell()
self.fd.seek(0, os.SEEK_END)
self.length = self.fd.tell()
self.fd.seek(current_position)
return self.length

def __setitem__(self):
raise Exception("This is read-only!")

0 comments on commit 5970b37

Please sign in to comment.