Commit f0e2ae3

linting

HaFred committed Feb 21, 2025
1 parent 1558518 commit f0e2ae3
Showing 18 changed files with 140 additions and 438 deletions.
62 changes: 17 additions & 45 deletions examples/janus/demo/app_januspro.py
@@ -24,18 +24,14 @@
 
 # args
 parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--ms_mode", type=int, default=1, help="mindspore mode, 0: graph, 1: pynative"
-)
+parser.add_argument("--ms_mode", type=int, default=1, help="mindspore mode, 0: graph, 1: pynative")
 parser.add_argument(
     "--model_path",
     type=str,
     default="ckpts/Janus-Pro-7B",
     help="path to model weight folder",
 )
-parser.add_argument(
-    "--share", type=str2bool, default=False, help="private or share demo (public)"
-)
+parser.add_argument("--share", type=str2bool, default=False, help="private or share demo (public)")
 args = parser.parse_args()
 
 # ms init
@@ -48,9 +44,7 @@
 config = AutoConfig.from_pretrained(args.model_path)
 language_config = config.language_config
 language_config._attn_implementation = "eager"
-vl_gpt = AutoModelForCausalLM.from_pretrained(
-    args.model_path, language_config=language_config, trust_remote_code=True
-)
+vl_gpt = AutoModelForCausalLM.from_pretrained(args.model_path, language_config=language_config, trust_remote_code=True)
 
 vl_gpt = set_model_param_dtype(vl_gpt, ms.bfloat16)
 vl_gpt.set_train(False)
@@ -78,9 +72,9 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
 
     pil_images = [Image.fromarray(image)]
 
-    prepare_inputs = vl_chat_processor(
-        conversations=conversation, images=pil_images, force_batchify=True
-    ).to(ms.bfloat16)
+    prepare_inputs = vl_chat_processor(conversations=conversation, images=pil_images, force_batchify=True).to(
+        ms.bfloat16
+    )
 
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
@@ -122,12 +116,8 @@ def generate(
         if i % 2 != 0:
             tokens[i, 1:-1] = vl_chat_processor.pad_id
 
-    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens).to(
-        vl_gpt.dtype
-    )
-    generated_tokens = mint.zeros(
-        (parallel_size, image_token_num_per_image), dtype=ms.int32
-    )
+    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens).to(vl_gpt.dtype)
+    generated_tokens = mint.zeros((parallel_size, image_token_num_per_image), dtype=ms.int32)
 
     use_cache = False
     outputs = None
@@ -152,18 +142,14 @@ def generate(
         next_token = mint.argmax(logits, dim=-1, keepdim=True)
 
         generated_tokens[:, i] = next_token.squeeze(axis=-1)
-        next_token = mint.cat(
-            [next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1
-        ).view(-1)
+        next_token = mint.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
 
         img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
 
         if use_cache:
             inputs_embeds = img_embeds.unsqueeze(dim=1)
         else:
-            inputs_embeds = ops.concat(
-                (inputs_embeds, img_embeds.unsqueeze(dim=1)), axis=1
-            )
+            inputs_embeds = ops.concat((inputs_embeds, img_embeds.unsqueeze(dim=1)), axis=1)
 
     patches = vl_gpt.gen_vision_model.decode_code(
         generated_tokens.to(dtype=ms.int32),
@@ -216,9 +202,7 @@ def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
         parallel_size=parallel_size,
         temperature=t2i_temperature,
     )
-    images = unpack(
-        patches, width // 16 * 16, height // 16 * 16, parallel_size=parallel_size
-    )
+    images = unpack(patches, width // 16 * 16, height // 16 * 16, parallel_size=parallel_size)
 
     # return [Image.fromarray(images[i]).resize((768, 768), Image.LANCZOS) for i in range(parallel_size)]
     return [Image.fromarray(images[i]) for i in range(parallel_size)]
@@ -232,12 +216,8 @@ def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
         with gr.Column():
             question_input = gr.Textbox(label="Question")
             und_seed_input = gr.Number(label="Seed", precision=0, value=42)
-            top_p = gr.Slider(
-                minimum=0, maximum=1, value=0.95, step=0.05, label="top_p"
-            )
-            temperature = gr.Slider(
-                minimum=0, maximum=1, value=0.1, step=0.05, label="temperature"
-            )
+            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
+            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
 
     understanding_button = gr.Button("Chat")
     understanding_output = gr.Textbox(label="Response")
@@ -260,16 +240,10 @@ def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
     gr.Markdown(value="# Text-to-Image Generation")
 
     with gr.Row():
-        cfg_weight_input = gr.Slider(
-            minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight"
-        )
-        t2i_temperature = gr.Slider(
-            minimum=0, maximum=1, value=1.0, step=0.05, label="temperature"
-        )
+        cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
+        t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="temperature")
 
-    prompt_input = gr.Textbox(
-        label="Prompt. (Prompt in more detail can help produce better images!)"
-    )
+    prompt_input = gr.Textbox(label="Prompt. (Prompt in more detail can help produce better images!)")
     seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
 
     generation_button = gr.Button("Generate Images")
@@ -316,6 +290,4 @@ def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
 if args.share:
     demo.launch(share=True)
 else:
-    demo.queue(concurrency_count=1, max_size=10).launch(
-        server_name="127.0.0.1", server_port=37906, root_path="/path"
-    )
+    demo.queue(concurrency_count=1, max_size=10).launch(server_name="127.0.0.1", server_port=37906, root_path="/path")
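
Note on the pattern in this file: the generation loop interleaves conditional and unconditional rows (odd rows padded to `pad_id`) so a single forward pass serves classifier-free guidance, and each sampled token is duplicated before being fed back. A minimal numpy sketch of that bookkeeping follows; the `uncond + cfg_weight * (cond - uncond)` combination is the standard CFG form and is an assumption here, since the combining line falls outside the visible hunks:

```python
import numpy as np

def cfg_sample_step(logits_pairs: np.ndarray, cfg_weight: float = 5.0):
    """One sampling step over interleaved cond/uncond logits.

    logits_pairs: (2 * parallel_size, vocab) -- even rows conditional,
    odd rows unconditional, matching the token layout in the diff.
    """
    cond, uncond = logits_pairs[0::2], logits_pairs[1::2]
    guided = uncond + cfg_weight * (cond - uncond)  # assumed standard CFG combination
    next_token = guided.argmax(axis=-1)             # greedy pick, like mint.argmax above
    # Duplicate each token so the next step again feeds cond/uncond pairs,
    # mirroring mint.cat([t, t], dim=1).view(-1) in the diff.
    paired = np.repeat(next_token, 2)               # (2 * parallel_size,)
    return next_token, paired

rng = np.random.default_rng(0)
tokens, paired = cfg_sample_step(rng.standard_normal((8, 16384)))
assert tokens.shape == (4,) and paired.shape == (8,)
```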
42 changes: 10 additions & 32 deletions examples/janus/generation_inference.py
@@ -46,26 +46,18 @@ def generate(
 
     inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens).to(mmgpt.dtype)
 
-    generated_tokens = mint.zeros(
-        (parallel_size, image_token_num_per_image), dtype=ms.int32
-    )
+    generated_tokens = mint.zeros((parallel_size, image_token_num_per_image), dtype=ms.int32)
 
     if use_cache:
-        init_kv = ms.mutable(
-            mmgpt.language_model.model.prepare_static_cache(
-                inputs_embeds, args.max_new_tokens
-            )
-        )
+        init_kv = ms.mutable(mmgpt.language_model.model.prepare_static_cache(inputs_embeds, args.max_new_tokens))
         # pad input emb for aligning the shape, meets graph mode
         emb_length = inputs_embeds.shape[-1] if inputs_embeds is not None else 0
         padded_inputs_embeds = ops.zeros(
             (inputs_embeds.shape[0], args.max_new_tokens, emb_length),
             inputs_embeds.dtype if inputs_embeds is not None else None,
         )
         for batch_idx in range(inputs_embeds.shape[0]):
-            padded_inputs_embeds[batch_idx, : inputs_embeds.shape[1]] = inputs_embeds[
-                batch_idx
-            ][:]
+            padded_inputs_embeds[batch_idx, : inputs_embeds.shape[1]] = inputs_embeds[batch_idx][:]
         inputs_embeds = padded_inputs_embeds
     else:
         init_kv = None
@@ -78,9 +70,7 @@ def generate(
         outputs = mmgpt.language_model.model(
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            past_key_values=ms.mutable(outputs[1])
-            if (i != 0 and use_cache)
-            else init_kv,
+            past_key_values=ms.mutable(outputs[1]) if (i != 0 and use_cache) else init_kv,
             return_dict=False,
         )
         hidden_states = outputs[0]
@@ -99,18 +89,14 @@ def generate(
 
         generated_tokens[:, i] = next_token.squeeze(axis=-1)
 
-        next_token = mint.cat(
-            [next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1
-        ).view(-1)
+        next_token = mint.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
 
         img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
 
         if use_cache:
             inputs_embeds = img_embeds.unsqueeze(dim=1)
         else:
-            inputs_embeds = ops.concat(
-                (inputs_embeds, img_embeds.unsqueeze(dim=1)), axis=1
-            )
+            inputs_embeds = ops.concat((inputs_embeds, img_embeds.unsqueeze(dim=1)), axis=1)
 
     time_cost = time() - st
     print(
@@ -141,9 +127,7 @@ def generate(
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--ms_mode", type=int, default=1, help="mindspore mode, 0: graph, 1: pynative"
-    )
+    parser.add_argument("--ms_mode", type=int, default=1, help="mindspore mode, 0: graph, 1: pynative")
     parser.add_argument(
         "--prompt",
         type=str,
@@ -168,9 +152,7 @@ def generate(
         default="ckpts/Janus-Pro-1B",
         help="path to model weight folder",
     )
-    parser.add_argument(
-        "--use_cache", type=str2bool, default=True, help="use kv cache or not"
-    )
+    parser.add_argument("--use_cache", type=str2bool, default=True, help="use kv cache or not")
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    parser.add_argument("--max_new_tokens", type=int, default=1024)
    args = parser.parse_args()
@@ -182,14 +164,10 @@ def generate(
     set_random_seed(args.seed)
 
     # specify the path to the model
-    vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(
-        args.model_path
-    )
+    vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(args.model_path)
     tokenizer = vl_chat_processor.tokenizer
 
-    vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
-        args.model_path
-    )
+    vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(args.model_path)
     dtype = ms.bfloat16
     vl_gpt = set_model_param_dtype(vl_gpt, dtype)
     vl_gpt.set_train(False)
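
The graph-mode branch in this file pre-allocates a static KV cache and zero-pads the prompt embeddings to `max_new_tokens`, so every compiled step sees fixed shapes. A shape-only numpy sketch of that padding (the function name is illustrative):

```python
import numpy as np

def pad_to_static_length(inputs_embeds: np.ndarray, max_new_tokens: int) -> np.ndarray:
    """Right-pad (batch, seq, dim) embeddings with zeros to a fixed length,
    mirroring the padded_inputs_embeds loop in the diff above."""
    batch, seq_len, dim = inputs_embeds.shape
    padded = np.zeros((batch, max_new_tokens, dim), dtype=inputs_embeds.dtype)
    padded[:, :seq_len] = inputs_embeds  # original tokens first, zeros after
    return padded

embeds = np.ones((2, 37, 2048), dtype=np.float32)
assert pad_to_static_length(embeds, 1024).shape == (2, 1024, 2048)
```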
22 changes: 7 additions & 15 deletions examples/janus/inference.py
@@ -44,9 +44,9 @@ def multimodal_understanding(
     tokenizer = vl_chat_processor.tokenizer
 
     pil_images = load_pil_images(conversation)
-    prepare_inputs = vl_chat_processor(
-        conversations=conversation, images=pil_images, force_batchify=True
-    ).to(vl_gpt.dtype)
+    prepare_inputs = vl_chat_processor(conversations=conversation, images=pil_images, force_batchify=True).to(
+        vl_gpt.dtype
+    )
 
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
     st = time()
@@ -81,15 +81,9 @@ def multimodal_understanding(
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--ms_mode", type=int, default=1, help="mindspore mode, 0: graph, 1: pynative"
-    )
-    parser.add_argument(
-        "--image", type=str, default="images/doge.png", help="path to input image"
-    )
-    parser.add_argument(
-        "--question", type=str, default="explain this meme", help="path to input image"
-    )
+    parser.add_argument("--ms_mode", type=int, default=1, help="mindspore mode, 0: graph, 1: pynative")
+    parser.add_argument("--image", type=str, default="images/doge.png", help="path to input image")
+    parser.add_argument("--question", type=str, default="explain this meme", help="path to input image")
     parser.add_argument(
         "--model_path",
         type=str,
@@ -118,9 +112,7 @@ def multimodal_understanding(
         ms.set_context(jit_config={"jit_level": "O0"})
 
     # specify the path to the model
-    vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(
-        args.model_path
-    )
+    vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(args.model_path)
 
     config = AutoConfig.from_pretrained(args.model_path)
     language_config = config.language_config
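
`prepare_inputs` in this file is built from a `conversation` list defined outside the visible hunks. In the upstream Janus examples it takes the following shape; the role tags and keys here are an assumption reconstructed from that reference, not part of this diff:

```python
# Hypothetical reconstruction of the conversation structure consumed by
# vl_chat_processor; role tags and keys follow the upstream Janus examples.
question = "explain this meme"
conversation = [
    {
        "role": "<|User|>",
        "content": f"<image_placeholder>\n{question}",
        "images": ["images/doge.png"],
    },
    {"role": "<|Assistant|>", "content": ""},
]
```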
28 changes: 7 additions & 21 deletions examples/janus/interactivechat.py
@@ -15,9 +15,7 @@
 vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
 
-vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
-    model_path, trust_remote_code=True
-)
+vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
 vl_gpt = vl_gpt.to(ms.bfloat16)
 
 
@@ -62,9 +60,7 @@ def generate(
 
     inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)
 
-    generated_tokens = mint.zeros(
-        (parallel_size, image_token_num_per_image), dtype=ms.int32
-    )
+    generated_tokens = mint.zeros((parallel_size, image_token_num_per_image), dtype=ms.int32)
     outputs = None  # Initialize outputs for use in the loop
 
     for i in range(image_token_num_per_image):
@@ -85,9 +81,7 @@
         next_token = mint.multinomial(probs, num_samples=1)
         generated_tokens[:, i] = next_token.squeeze(dim=-1)
 
-        next_token = mint.cat(
-            [next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1
-        ).view(-1)
+        next_token = mint.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
         img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
         inputs_embeds = img_embeds.unsqueeze(dim=1)
 
@@ -112,9 +106,7 @@
 
     # Save images with timestamp and part of the user prompt in the filename
     for i in range(parallel_size):
-        save_path = os.path.join(
-            "generated_samples", f"img_{timestamp}_{short_prompt}_{i}.jpg"
-        )
+        save_path = os.path.join("generated_samples", f"img_{timestamp}_{short_prompt}_{i}.jpg")
         PIL.Image.fromarray(visual_img[i]).save(save_path)
 
 
@@ -123,19 +115,15 @@ def interactive_image_generator():
 
     # Ask for the number of images at the start of the session
     while True:
-        num_images_input = input(
-            "How many images would you like to generate per prompt? (Enter a positive integer): "
-        )
+        num_images_input = input("How many images would you like to generate per prompt? (Enter a positive integer): ")
         if num_images_input.isdigit() and int(num_images_input) > 0:
             parallel_size = int(num_images_input)
             break
         else:
             print("Invalid input. Please enter a positive integer.")
 
     while True:
-        user_input = input(
-            "Please describe the image you'd like to generate (or type 'exit' to quit): "
-        )
+        user_input = input("Please describe the image you'd like to generate (or type 'exit' to quit): ")
 
         if user_input.lower() == "exit":
             print("Exiting the image generator. Goodbye!")
@@ -155,9 +143,7 @@ def interactive_image_generator():
             parallel_size=parallel_size,  # Pass the user-specified number of images
         )
 
-        print(
-            "Image generation complete! Check the 'generated_samples' folder for the output.\n"
-        )
+        print("Image generation complete! Check the 'generated_samples' folder for the output.\n")
 
 
 if __name__ == "__main__":
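
The save loop in this file composes `img_{timestamp}_{short_prompt}_{i}.jpg` paths. Neither the timestamp format nor the prompt truncation is visible in the hunks, so the helper below is a hypothetical sketch of that bookkeeping:

```python
import os
import re
from datetime import datetime

def make_save_path(prompt: str, index: int, out_dir: str = "generated_samples") -> str:
    """Build an img_{timestamp}_{short_prompt}_{i}.jpg path like the save loop above.

    Both the timestamp format and the 20-char prompt truncation are assumptions.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    short_prompt = re.sub(r"\W+", "_", prompt)[:20]  # keep the filename filesystem-safe
    os.makedirs(out_dir, exist_ok=True)
    return os.path.join(out_dir, f"img_{timestamp}_{short_prompt}_{index}.jpg")

print(make_save_path("a cute shiba inu wearing a hat", 0))
```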
4 changes: 1 addition & 3 deletions examples/janus/janus/models/clip_encoder.py
@@ -52,9 +52,7 @@ def __init__(
             "select_layer": select_layer,
         }
         vision_tower_params.update(kwargs)
-        self.vision_tower, self.forward_kwargs = self.build_vision_tower(
-            vision_tower_params
-        )
+        self.vision_tower, self.forward_kwargs = self.build_vision_tower(vision_tower_params)
 
         if pixel_mean is not None and pixel_std is not None:
             image_norm = vision.transforms.Normalize(mean=pixel_mean, std=pixel_std)
4 changes: 1 addition & 3 deletions examples/janus/janus/models/image_processing_vlm.py
@@ -164,9 +164,7 @@ def resize(self, pil_img: Image) -> np.ndarray:
         else:
             from mindspore.dataset.vision import Inter
 
-            pil_img = ms.dataset.vision.Resize(size, interpolation=Inter.ANTIALIAS)(
-                pil_img
-            )
+            pil_img = ms.dataset.vision.Resize(size, interpolation=Inter.ANTIALIAS)(pil_img)
 
         pil_img = expand2square(pil_img, self.background_color)
         x = to_numpy_array(pil_img)
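
`expand2square` is called in the hunk above but defined elsewhere in the repo. A common implementation of this helper (as in LLaVA-style processors) pads the image onto a square canvas filled with the background color; this is an assumed sketch of its behavior, not the repo's definition:

```python
from PIL import Image

def expand2square(pil_img: Image.Image, background_color) -> Image.Image:
    """Pad a PIL image to a square canvas, centering the original content.

    Assumed behavior; the actual helper lives elsewhere in the repo.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return canvas

img = expand2square(Image.new("RGB", (640, 360), (255, 0, 0)), (127, 127, 127))
assert img.size == (640, 640)
```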