Skip to content

Commit

Permalink
Merge branch 'main' of github.com:myshell-ai/ShellAgent into 1-bug-fi…
Browse files Browse the repository at this point in the history
…xes-and-improvements-0918
  • Loading branch information
myshell-joe committed Sep 20, 2024
2 parents 0c4808c + 0b8efb5 commit 2827ff5
Show file tree
Hide file tree
Showing 12 changed files with 441 additions and 79 deletions.
163 changes: 96 additions & 67 deletions folder_paths.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from __future__ import annotations

import os
import time
import mimetypes
import logging
import tempfile
import requests
from urllib.parse import urlparse
from typing import Any, IO
from typing import Set, List, Dict, Tuple, Literal
from collections.abc import Collection

supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl'])
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}

folder_names_and_paths = {}
folder_names_and_paths: dict[str, tuple[list[str], set[str]]] = {}

base_path = os.path.dirname(os.path.realpath(__file__))
models_dir = os.path.join(base_path, "models")
Expand All @@ -18,7 +19,7 @@
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions)
folder_names_and_paths["diffusion_models"] = ([os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], supported_pt_extensions)
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
Expand All @@ -30,7 +31,7 @@

folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)

folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], [])
folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], set())

folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)

Expand All @@ -43,41 +44,56 @@
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user")

filename_list_cache = {}
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}

extension_mimetypes_cache = {
"webp" : "image",
}

def map_legacy(folder_name: str) -> str:
legacy = {"unet": "diffusion_models"}
return legacy.get(folder_name, folder_name)

if not os.path.exists(input_directory):
try:
os.makedirs(input_directory)
except:
logging.error("Failed to create input directory")

def set_output_directory(output_dir):
def set_output_directory(output_dir: str) -> None:
global output_directory
output_directory = output_dir

def set_temp_directory(temp_dir):
def set_temp_directory(temp_dir: str) -> None:
global temp_directory
temp_directory = temp_dir

def set_input_directory(input_dir):
def set_input_directory(input_dir: str) -> None:
global input_directory
input_directory = input_dir

def get_output_directory():
def get_output_directory() -> str:
global output_directory
return output_directory

def get_temp_directory():
def get_temp_directory() -> str:
global temp_directory
return temp_directory

def get_input_directory():
def get_input_directory() -> str:
global input_directory
return input_directory

def get_user_directory() -> str:
return user_directory

def set_user_directory(user_dir: str) -> None:
global user_directory
user_directory = user_dir


#NOTE: used in http server so don't put folders that should not be accessed remotely
def get_directory_by_type(type_name):
def get_directory_by_type(type_name: str) -> str | None:
if type_name == "output":
return get_output_directory()
if type_name == "temp":
Expand All @@ -86,39 +102,32 @@ def get_directory_by_type(type_name):
return get_input_directory()
return None


def is_valid_url(candidate_str: Any) -> bool:
if not isinstance(candidate_str, str):
return False
parsed = urlparse(candidate_str)
return parsed.scheme != "" and parsed.netloc != ""

def _make_temp_file(file_path: str) -> IO:
def filter_files_content_types(files: List[str], content_types: Literal["image", "video", "audio"]) -> List[str]:
"""
A utility function to write bytes to a temporary file. This is useful
if one needs to pass a file object to a function, but only has bytes.
Example:
files = os.listdir(folder_paths.get_input_directory())
filter_files_content_types(files, ["image", "audio", "video"])
"""
# If the source is a valid url, we will download the content and return it.
try:
content = requests.get(file_path).content
except Exception:
raise ValueError(f"Failed to download content from url: {file_path}")

_, extension = os.path.splitext(file_path)

# get the temp file path
f = tempfile.NamedTemporaryFile(delete=False, suffix=extension)
f.write(content)
# Flush to make sure that the content is written.
f.flush()
# Seek to the beginning of the file so that the content can be read.
f.seek(0)
print(f'download url resources success to {f.name}')
return f
global extension_mimetypes_cache
result = []
for file in files:
extension = file.split('.')[-1]
if extension not in extension_mimetypes_cache:
mime_type, _ = mimetypes.guess_type(file, strict=False)
if not mime_type:
continue
content_type = mime_type.split('/')[0]
extension_mimetypes_cache[extension] = content_type
else:
content_type = extension_mimetypes_cache[extension]

if content_type in content_types:
result.append(file)
return result

# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
# otherwise use default_path as base_dir
def annotated_filepath(name):
def annotated_filepath(name: str) -> tuple[str, str | None]:
if name.endswith("[output]"):
base_dir = get_output_directory()
name = name[:-9]
Expand All @@ -134,16 +143,10 @@ def annotated_filepath(name):
return name, base_dir


def get_annotated_filepath(name, default_dir=None):
# norm path
def get_annotated_filepath(name: str, default_dir: str | None=None) -> str:
name, base_dir = annotated_filepath(name)

if base_dir is None:
# find a https url
if is_valid_url(name):
src = _make_temp_file(name)
return src.name

if default_dir is not None:
base_dir = default_dir
else:
Expand All @@ -152,7 +155,7 @@ def get_annotated_filepath(name, default_dir=None):
return os.path.join(base_dir, name)


def exists_annotated_filepath(name):
def exists_annotated_filepath(name) -> bool:
name, base_dir = annotated_filepath(name)

if base_dir is None:
Expand All @@ -162,17 +165,19 @@ def exists_annotated_filepath(name):
return os.path.exists(filepath)


def add_model_folder_path(folder_name, full_folder_path):
def add_model_folder_path(folder_name: str, full_folder_path: str) -> None:
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name in folder_names_and_paths:
folder_names_and_paths[folder_name][0].append(full_folder_path)
else:
folder_names_and_paths[folder_name] = ([full_folder_path], set())

def get_folder_paths(folder_name):
def get_folder_paths(folder_name: str) -> list[str]:
folder_name = map_legacy(folder_name)
return folder_names_and_paths[folder_name][0][:]

def recursive_search(directory, excluded_dir_names=None):
def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]:
if not os.path.isdir(directory):
return [], {}

Expand All @@ -189,14 +194,18 @@ def recursive_search(directory, excluded_dir_names=None):
logging.warning(f"Warning: Unable to access {directory}. Skipping this path.")

logging.debug("recursive file list on directory {}".format(directory))
dirpath: str
subdirs: list[str]
filenames: list[str]

for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
for file_name in filenames:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)

for d in subdirs:
path = os.path.join(dirpath, d)
path: str = os.path.join(dirpath, d)
try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
Expand All @@ -205,13 +214,14 @@ def recursive_search(directory, excluded_dir_names=None):
logging.debug("found {} files".format(len(result)))
return result, dirs

def filter_files_extensions(files, extensions):
def filter_files_extensions(files: Collection[str], extensions: Collection[str]) -> list[str]:
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))



def get_full_path(folder_name, filename):
def get_full_path(folder_name: str, filename: str) -> str | None:
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in folder_names_and_paths:
return None
folders = folder_names_and_paths[folder_name]
Expand All @@ -225,7 +235,16 @@ def get_full_path(folder_name, filename):

return None

def get_filename_list_(folder_name):

def get_full_path_or_raise(folder_name: str, filename: str) -> str:
full_path = get_full_path(folder_name, filename)
if full_path is None:
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
return full_path


def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
folder_name = map_legacy(folder_name)
global folder_names_and_paths
output_list = set()
folders = folder_names_and_paths[folder_name]
Expand All @@ -235,11 +254,12 @@ def get_filename_list_(folder_name):
output_list.update(filter_files_extensions(files, folders[1]))
output_folders = {**output_folders, **folders_all}

return (sorted(list(output_list)), output_folders, time.perf_counter())
return sorted(list(output_list)), output_folders, time.perf_counter()

def cached_filename_list_(folder_name):
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
global filename_list_cache
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in filename_list_cache:
return None
out = filename_list_cache[folder_name]
Expand All @@ -258,30 +278,39 @@ def cached_filename_list_(folder_name):

return out

def get_filename_list(folder_name):
def get_filename_list(folder_name: str) -> list[str]:
folder_name = map_legacy(folder_name)
out = cached_filename_list_(folder_name)
if out is None:
out = get_filename_list_(folder_name)
global filename_list_cache
filename_list_cache[folder_name] = out
return list(out[0])

def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
def map_filename(filename):
def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
def map_filename(filename: str) -> tuple[int, str]:
prefix_len = len(os.path.basename(filename_prefix))
prefix = filename[:prefix_len + 1]
try:
digits = int(filename[prefix_len + 1:].split('_')[0])
except:
digits = 0
return (digits, prefix)
return digits, prefix

def compute_vars(input, image_width, image_height):
def compute_vars(input: str, image_width: int, image_height: int) -> str:
input = input.replace("%width%", str(image_width))
input = input.replace("%height%", str(image_height))
now = time.localtime()
input = input.replace("%year%", str(now.tm_year))
input = input.replace("%month%", str(now.tm_mon).zfill(2))
input = input.replace("%day%", str(now.tm_mday).zfill(2))
input = input.replace("%hour%", str(now.tm_hour).zfill(2))
input = input.replace("%minute%", str(now.tm_min).zfill(2))
input = input.replace("%second%", str(now.tm_sec).zfill(2))
return input

filename_prefix = compute_vars(filename_prefix, image_width, image_height)
if "%" in filename_prefix:
filename_prefix = compute_vars(filename_prefix, image_width, image_height)

subfolder = os.path.dirname(os.path.normpath(filename_prefix))
filename = os.path.basename(os.path.normpath(filename_prefix))
Expand Down
33 changes: 32 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions proconfig/widgets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import proconfig.widgets.imagen_widgets
import proconfig.widgets.language_models
import proconfig.widgets.tools
import proconfig.widgets.myshell_widgets
# load custom widgets

import os
Expand Down
1 change: 1 addition & 0 deletions proconfig/widgets/myshell_widgets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from proconfig.widgets.myshell_widgets.tools.image_text_fuser import ImageTextFuserWidget
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit 2827ff5

Please sign in to comment.