Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add parameters --modelfile #1286

Merged
merged 4 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions visualdl/reader/graph_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class GraphReader(object):
"""Graph reader to read vdl graph files, support for frontend api in lib.py.
"""

def __init__(self, logdir=''):
def __init__(self, logdir='', model_name=''):
"""Instance of GraphReader

Args:
Expand All @@ -52,6 +52,7 @@ def __init__(self, logdir=''):
else:
self.dir = logdir

self.model_name = model_name
self.walks = {}
self.displayname2runs = {}
self.runs2displayname = {}
Expand Down Expand Up @@ -102,7 +103,10 @@ def graphs(self, update=False):
]
tags_temp.sort(reverse=True)
if len(tags_temp) > 0:
walks_temp.update({run: tags_temp[0]})
if self.model_name:
walks_temp.update({run: self.model_name})
else:
walks_temp.update({run: tags_temp[0]})
self.walks = walks_temp
return self.walks

Expand Down
14 changes: 9 additions & 5 deletions visualdl/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,13 @@ def try_call(function, *args, **kwargs):


class Api(object):
def __init__(self, logdir, model, cache_timeout):
def __init__(self, logdir, model, modelfile, cache_timeout):
self.model_name = ''
if not logdir and modelfile:
logdir = os.path.dirname(modelfile)
self.model_name = os.path.basename(modelfile)
self._reader = LogReader(logdir)
self._graph_reader = GraphReader(logdir)
self._graph_reader = GraphReader(logdir, self.model_name)
self._graph_reader.set_displayname(self._reader)
if model:
if 'vdlgraph' in model:
Expand Down Expand Up @@ -415,7 +419,7 @@ def get_component_tabs(*apis, vdl_args, request_args):
all_tabs = set()
if vdl_args.component_tabs:
return list(vdl_args.component_tabs)
if vdl_args.logdir:
if vdl_args.logdir or vdl_args.modelfile:
for api in apis:
all_tabs.update(api('component_tabs', request_args))
all_tabs.add('static_graph')
Expand All @@ -427,8 +431,8 @@ def get_component_tabs(*apis, vdl_args, request_args):
return list(all_tabs)


def create_api_call(logdir, model, cache_timeout):
api = Api(logdir, model, cache_timeout)
def create_api_call(logdir, model, modelfile, cache_timeout):
api = Api(logdir, model, modelfile, cache_timeout)
routes = {
'components': (api.components, []),
'runs': (api.runs, []),
Expand Down
2 changes: 1 addition & 1 deletion visualdl/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def get_locale():
) # we add this to prevent SIGINT not work in multiprocess queue waiting
babel = Babel(app, locale_selector=get_locale) # noqa:F841
# Babel api from flask_babel v3.0.0
api_call = create_api_call(args.logdir, args.model, args.cache_timeout)
api_call = create_api_call(args.logdir, args.model, args.modelfile, args.cache_timeout)
profiler_api_call = create_profiler_api_call(args.logdir)
inference_api_call = create_model_convert_api_call()
fastdeploy_api_call = create_fastdeploy_api_call()
Expand Down
9 changes: 9 additions & 0 deletions visualdl/server/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, args):
self.api_only = args.get('api_only', False)
self.open_browser = args.get('open_browser', False)
self.model = args.get('model', '')
self.modelfile = args.get('modelfile', '')
self.product = args.get('product', default_product)
self.telemetry = args.get('telemetry', True)
self.theme = args.get('theme', None)
Expand Down Expand Up @@ -123,6 +124,7 @@ def __init__(self, **kwargs):
self.api_only = args.api_only
self.open_browser = args.open_browser
self.model = args.model
self.modelfile = args.modelfile
self.product = args.product
self.telemetry = args.telemetry
self.theme = args.theme
Expand All @@ -141,6 +143,13 @@ def parse_args():
epilog="For more information: https://github.com/PaddlePaddle/VisualDL"
)

parser.add_argument(
"--modelfile",
type=str,
action="store",
default="",
help="json model file path")

parser.add_argument(
"--logdir", action="store", nargs="+", help="log file directory")

Expand Down
Loading