Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multisession modeling changes for AutoLFADS #50

Open
wants to merge 10 commits into
base: reorganize
Choose a base branch
from
37 changes: 37 additions & 0 deletions src/neurocaas_contrib/Interface_S3.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,43 @@ def download(s3path,localpath,display = False):
else:
raise

def download_multi(s3path,localpath,force,display = False):
"""Download function. Takes an s3 path to a "folder" (path prefix that ends with backslack), and local object path as input. Will attempt to download all data at the given location to the local path.
:param s3path: full path to an object in s3. Assumes the s3://bucketname/key syntax.
:param localpath: full path to the object name locally (i.e. with basename attached).
:param force: will not redownload if data of the same name already lives here
:param display: (optional) Defaults to false. If true, displays a progress bar.
:return: bool (True if successful download for all files, False otherwise)


"""
assert s3path.startswith("s3://")
bucketname,keyname = s3path.split("s3://")[-1].split("/",1)

try:
transfer = S3Transfer(s3_client)


# adapted from https://stackoverflow.com/questions/49772151/download-a-folder-from-s3-using-boto3
bucket = s3.Bucket(bucketname)
no_duplicate = 1
for obj in bucket.objects.filter(Prefix = keyname):
obj_keyname = obj.key
if (os.path.basename(obj_keyname) in os.listdir(localpath)) and (not force):
print("Data already exists at this location. Set force = true to overwrite")
no_duplicate = 0
else:
progress = ProgressPercentage_d(transfer._manager._client,bucketname,obj_keyname,display = display)
transfer.download_file(bucketname,obj_keyname,os.path.join(localpath,os.path.basename(obj_keyname)),callback = progress)
return no_duplicate

except botocore.exceptions.ClientError as e:
if e.response['Error']['Code'] == "404":
print("The object does not exist.")
raise
else:
raise

def upload(localpath,s3path,display = False):
"""Upload function. Takes a local object paht and s3 path to the desired key as input.
:param localpath: full path to the object name locally (i.e. with basename attached).
Expand Down
29 changes: 29 additions & 0 deletions src/neurocaas_contrib/cli_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,35 @@ def get_data(obj,outputpath,force,display):
kwargs["display"] = display
ncsm.get_data(**kwargs)


@workflow.command(help = "get multiple registered datasets from S3")
@click.option("-o",
"--outputpath",
help = "path to write output to.",
default = None)
@click.option("-f",
"--force",
help = "if true, will redownload even if exists at intended output location",
is_flag = True)
@click.option("-d",
"--display",
help = "if true, will show download progress",
is_flag = True)
@click.pass_obj
def get_data_multi(obj,outputpath,force,display):
"""Gets multiple registered datasets from S3.

"""
path = obj["storage"]["path"]
ncsm = NeuroCAASScriptManager.from_registration(path)
kwargs = {}
if outputpath is not None:
kwargs["path"] = outputpath
kwargs["force"] = force
kwargs["display"] = display
ncsm.get_data_multi(**kwargs)


@workflow.command(help = "get a registered config from S3")
@click.option("-o",
"--outputpath",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
{
"PipelineName": "autolfads-torch",
"REGION": "us-east-1",
"STAGE": "websubstack",
"Lambda": {
"CodeUri": "../../protocols",
"Handler": "submit_start.handler_multisession",
"Launch": true,
"LambdaConfig": {
"AMI": "ami-08f59b6fdb19d0344",
"INSTANCE_TYPE": "p2.8xlarge",
"REGION": "us-east-1",
"IAM_ROLE": "SSMRole",
"KEY_NAME": "testkeystack-custom-dev-key-pair",
"WORKING_DIRECTORY": "~/bin",
"COMMAND": "cd /home/ubuntu; sudo -u ubuntu neurocaas_contrib/run_main_cli.sh \"{}\" \"{}\" \"{}\" \"{}\" \"lfads-torch/run_main.sh\"; . neurocaas_contrib/ncap_utils/workflow.sh;",
"EXECUTION_TIMEOUT": 900,
"SSM_TIMEOUT": 172000
}
},
"UXData": {
"Affiliates": [
{
"AffiliateName": "traviscipermagroup",
"UserNames": [
"cipermauser1",
"cipermauser2"
],
"UserInput": true,
"ContactEmail": "NOTE: KEEP THIS AFFILIATE TO ENABLE EASY TESTING"
},
{
"AffiliateName": "systemsneura1639759179",
"UserNames": [
"systemsneura1639759178"
],
"UserInput": true,
"ContactEmail": "NOTE: KEEP THIS AFFILIATE TO ENABLE EASY TESTING"
}
]
}
}
85 changes: 64 additions & 21 deletions src/neurocaas_contrib/scripting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import json
import zipfile
from .log import NeuroCAASCertificate,NeuroCAASDataStatus,NeuroCAASDataStatusLegacy
from .Interface_S3 import download,upload
from .Interface_S3 import download,upload,download_multi

dir_loc = os.path.abspath(os.path.dirname(__file__))

Expand Down Expand Up @@ -166,7 +166,7 @@ def __init__(self,path,write = True):
self.path = path
## The subdirectories to expect/create at the given location.
self.subdirs = {"data":"inputs","config":"configs","results":"results","logs":"logs"}
#self.pathtemplate = {"s3":None,"localsource":None,"local":None}
#self.pathtemplate = {"s3":None,"local":None,"local":None}
self.registration = {
"data":{},
"config":{},
Expand All @@ -178,6 +178,7 @@ def __init__(self,path,write = True):
self.write()

def write(self):
print("\n\n\n\n" + "Registering at {}".format(str(self.path)) + "\n\n\n\n")
with open(os.path.join(self.path,"registration.json"),"w") as reg:
json.dump(self.registration,reg)

Expand Down Expand Up @@ -205,7 +206,7 @@ def register_data(self,s3path):
## canc check existence later.
assert str(s3path).startswith("s3://"), "must be given in s3 form"
self.registration["data"]["s3"] = str(s3path)
self.registration["data"].pop("localsource","False")
self.registration["data"].pop("local","False")
self.registration["data"].pop("local","False")
self.write()

Expand All @@ -215,7 +216,7 @@ def register_data_local(self,localpath):

"""
## canc check existence later.
self.registration["data"]["localsource"] = str(localpath)
self.registration["data"]["local"] = str(localpath)
self.registration["data"].pop("s3","False")
self.registration["data"].pop("local","False")
self.write()
Expand All @@ -228,7 +229,7 @@ def register_config(self,s3path):
## canc check existence later.
assert str(s3path).startswith("s3://"), "must be given in s3 form"
self.registration["config"]["s3"] = str(s3path)
self.registration["config"].pop("localsource","False")
self.registration["config"].pop("local","False")
self.registration["config"].pop("local","False")
self.write()

Expand All @@ -238,7 +239,7 @@ def register_config_local(self,localpath):

"""
## canc check existence later.
self.registration["config"]["localsource"] = str(localpath)
self.registration["config"]["local"] = str(localpath)
self.registration["config"].pop("s3","False")
self.registration["config"].pop("local","False")
self.write()
Expand All @@ -255,7 +256,7 @@ def register_file(self,name,s3path):
self.registration["additional_files"][name] = {}
## populate
self.registration["additional_files"][name]["s3"] = str(s3path)
self.registration["additional_files"][name].pop("localsource","False")
self.registration["additional_files"][name].pop("local","False")
self.registration["additional_files"][name].pop("local","False")
self.write()

Expand All @@ -269,7 +270,7 @@ def register_file_local(self,name,localpath):
#self.registration["additional_files"][name] = {k:v for k,v in self.pathtemplate.items()}
self.registration["additional_files"][name] = {}
## populate
self.registration["additional_files"][name]["localsource"] = str(localpath)
self.registration["additional_files"][name]["local"] = str(localpath)
self.registration["additional_files"][name].pop("s3","False")
self.registration["additional_files"][name].pop("local","False")
self.write()
Expand All @@ -280,14 +281,14 @@ def register_resultpath(self,s3path):
"""
assert s3path.startswith("s3://"), "must be given in s3 form"
self.registration["resultpath"]["s3"] = str(s3path)
self.registration["resultpath"].pop("localsource","False")
self.registration["resultpath"].pop("local","False")
self.write()

def register_resultpath_local(self,localpath):
"""Given an local path, registers that as the location where we will upload job data. Give a folder, where you want to generate two subdirectories, "logs", and "process_results". Logs and analysis results will be sent to these respective locations.

"""
self.registration["resultpath"]["localsource"] = str(localpath)
self.registration["resultpath"]["local"] = str(localpath)
self.registration["resultpath"].pop("s3","False")
self.write()

Expand All @@ -305,8 +306,8 @@ def get_data(self,path = None,force = False,display = False):
source = "s3"
except KeyError:
try:
data_localsource = self.registration["data"]["localsource"]
data_name = os.path.basename(data_localsource)
data_local = self.registration["data"]["local"]
data_name = os.path.basename(data_local)
source = "local"
except:
raise AssertionError("Data not registered. Run register_data first.")
Expand All @@ -325,10 +326,52 @@ def get_data(self,path = None,force = False,display = False):
if source == "s3":
download(data_s3path,data_localpath,display)
elif source == "local":
shutil.copy(data_localsource,data_localpath)
shutil.copy(data_local,data_localpath)
self.registration["data"]["local"] = data_localpath
self.write()
return 1

def get_data_multi(self,path = None,force = False,display = False):
"""Get currently registered data. If desired, you can pass a path where you would like data to be moved. Otherwise, it will be moved to self.path/self.subdirs[data]
:param path: (optional) the location you want to write data to.
:param force: (optional) by default, will not redownload if data of the same name already lives here. Can override with force = True
:param display: (optional) by default, will not display downlaod progress.
:return: bool (True if downloaded, False if not)

"""
try:
data_s3path = self.registration["data"]["s3"]
source = "s3"
except KeyError:
try:
data_local = self.registration["data"]["local"]
source = "local"
except:
raise AssertionError("Data not registered. Run register_data first.")

if path is None:
path = os.path.join(self.path,self.subdirs["data"])
mkdir_notexists(path)
#pass the local directory instead of a filename -- will populate with all files in remote dir
data_localpath = path

if source == "s3":
download_success = download_multi(data_s3path,data_localpath,force,display)
if not download_success:
return 0
elif source == "local":
for filename in os.listdir(data_local):
source_name = os.path.join(data_local,filename)
dest_name = os.path.join(data_localpath,filename)

if os.path.exists(dest_name) and not force:
print(f"{filename} already exists at this location. Set force = true to overwrite")
return 0

shutil.copy(source_name,dest_name)
self.registration["data"]["local"] = str(data_localpath)
self.write()
return 1

def get_config(self,path = None,force = False,display = False):
"""Get currently registered config. If desired, you can pass a path where you would like config to be moved. Otherwise, it will be moved to self.path/self.subdirs[config]
Expand All @@ -344,8 +387,8 @@ def get_config(self,path = None,force = False,display = False):
source = "s3"
except KeyError:
try:
config_localsource = self.registration["config"]["localsource"]
config_name = os.path.basename(config_localsource)
config_local = self.registration["config"]["local"]
config_name = os.path.basename(config_local)
source = "local"
except:
raise AssertionError("Config not registered. Run register_config first.")
Expand All @@ -364,7 +407,7 @@ def get_config(self,path = None,force = False,display = False):
if source == "s3":
download(config_s3path,config_localpath,display)
elif source == "local":
shutil.copy(config_localsource,config_localpath)
shutil.copy(config_local,config_localpath)
self.registration["config"]["local"] = config_localpath
self.write()
return 1
Expand All @@ -385,8 +428,8 @@ def get_file(self,varname,path = None,force = False,display = False):
source = "s3"
except KeyError:
try:
file_localsource = self.registration["additional_files"][varname]["localsource"]
file_name = os.path.basename(file_localsource)
file_local = self.registration["additional_files"][varname]["local"]
file_name = os.path.basename(file_local)
source = "local"
except:
raise AssertionError("File not registered. Run register_file first.")
Expand All @@ -405,7 +448,7 @@ def get_file(self,varname,path = None,force = False,display = False):
if source == "s3":
download(file_s3path,file_localpath,display)
elif source == "local":
shutil.copy(file_localsource,file_localpath)
shutil.copy(file_local,file_localpath)
self.registration["additional_files"][varname]["local"] = file_localpath
self.write()
return 1
Expand All @@ -422,7 +465,7 @@ def put_result(self,localfile,display = False):
upload(localfile,fullpath,display)
except KeyError:
try:
fullpath = os.path.join(self.registration["resultpath"]["localsource"],"process_results",filename)
fullpath = os.path.join(self.registration["resultpath"]["local"],"process_results",filename)
os.makedirs(os.path.dirname(fullpath),exist_ok = True)
shutil.copy(localfile,fullpath)
except:
Expand Down Expand Up @@ -521,7 +564,7 @@ def get_resultpath(self,filepath):
resultpath = os.path.join(self.registration["resultpath"]["s3"],"process_results",basename)
except KeyError:
try:
resultpath = os.path.join(self.registration["resultpath"]["localsource"],"process_results",basename)
resultpath = os.path.join(self.registration["resultpath"]["local"],"process_results",basename)
except KeyError:
raise KeyError("Not registered.")
return resultpath
Expand Down
Loading