Skip to content

Commit

Permalink
Merge pull request datajoint#95 from MilagrosMarin/element-dlc-test
Browse files Browse the repository at this point in the history
minor changes to standardize the elements
  • Loading branch information
kushalbakshi authored Nov 7, 2023
2 parents 1314488 + b33ece9 commit ded35b3
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 67 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,21 @@ tables that can be combined with other Elements to assemble a fully functional p

## Getting Started

+ Please fork this repository
+ Please fork this repository.

+ Clone the repository to your computer
+ Clone the repository to your computer.

```bash
git clone https://github.com/<enter_github_username>/element-deeplabcut
```

+ Install with `pip`
+ Install with `pip`:

```bash
pip install -e .
```

+ [Interactive tutorial on GitHub Codespaces](#interactive-tutorial)
+ [Interactive tutorial on GitHub Codespaces](https://github.com/datajoint/element-deeplabcut#interactive-tutorial)

+ [Documentation](https://datajoint.com/docs/elements/element-deeplabcut)

Expand Down
2 changes: 1 addition & 1 deletion element_deeplabcut/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@
"DLC_PROCESSED_DATA_DIR", dj.config["custom"].get("dlc_processed_data_dir", "")
)

db_prefix = dj.config["custom"].get("database.prefix", "")
db_prefix = dj.config["custom"].get("database.prefix", "")
143 changes: 81 additions & 62 deletions notebooks/tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,12 @@
"outputs": [],
"source": [
"import os\n",
"if os.path.basename(os.getcwd())=='notebooks': os.chdir('..')\n",
"assert os.path.basename(os.getcwd())=='element-deeplabcut', (\"Please move to the \"\n",
" + \"element directory\")"
"\n",
"if os.path.basename(os.getcwd()) == \"notebooks\":\n",
" os.chdir(\"..\")\n",
"assert os.path.basename(os.getcwd()) == \"element-deeplabcut\", (\n",
" \"Please move to the \" + \"element directory\"\n",
")"
]
},
{
Expand Down Expand Up @@ -201,7 +204,7 @@
}
],
"source": [
"from tutorial_pipeline import lab, subject, session, train, model "
"from tutorial_pipeline import lab, subject, session, train, model"
]
},
{
Expand Down Expand Up @@ -990,10 +993,10 @@
],
"source": [
"(\n",
" dj.Diagram(subject) \n",
" + dj.Diagram(lab) \n",
" + dj.Diagram(session) \n",
" + dj.Diagram(model) \n",
" dj.Diagram(subject)\n",
" + dj.Diagram(lab)\n",
" + dj.Diagram(session)\n",
" + dj.Diagram(model)\n",
" + dj.Diagram(train)\n",
")"
]
Expand Down Expand Up @@ -1274,7 +1277,9 @@
"metadata": {},
"outputs": [],
"source": [
"config_file_rel = \"./example_data/inbox/from_top_tracking-DataJoint-2023-10-11/config.yaml\""
"config_file_rel = (\n",
" \"./example_data/inbox/from_top_tracking-DataJoint-2023-10-11/config.yaml\"\n",
")"
]
},
{
Expand Down Expand Up @@ -1358,11 +1363,13 @@
}
],
"source": [
"model.Model.insert_new_model(model_name='from_top_tracking_model_test',\n",
" dlc_config=config_file_rel,\n",
" shuffle=1,\n",
" trainingsetindex=0,\n",
" model_description='Model in example data: from_top_tracking model')"
"model.Model.insert_new_model(\n",
" model_name=\"from_top_tracking_model_test\",\n",
" dlc_config=config_file_rel,\n",
" shuffle=1,\n",
" trainingsetindex=0,\n",
" model_description=\"Model in example data: from_top_tracking model\",\n",
")"
]
},
{
Expand Down Expand Up @@ -1668,14 +1675,14 @@
"metadata": {},
"outputs": [],
"source": [
"#Definition of the dictionary named \"session_keys\"\n",
"# Definition of the dictionary named \"session_keys\"\n",
"session_keys = [\n",
" dict(subject=\"subject6\", session_datetime=\"2021-06-02 14:04:22\"),\n",
" dict(subject=\"subject6\", session_datetime=\"2021-06-03 14:43:10\"),\n",
"]\n",
"\n",
"#Insert this dictionary in the Session table\n",
"session.Session.insert(session_keys, skip_duplicates=True)\n"
"# Insert this dictionary in the Session table\n",
"session.Session.insert(session_keys, skip_duplicates=True)"
]
},
{
Expand Down Expand Up @@ -1791,10 +1798,14 @@
"metadata": {},
"outputs": [],
"source": [
"recording_key = {'subject': 'subject6',\n",
" 'session_datetime': '2021-06-02 14:04:22',\n",
" 'recording_id': '1'}\n",
"model.VideoRecording.insert1({**recording_key, 'device': 'Camera1'}, skip_duplicates=True)"
"recording_key = {\n",
" \"subject\": \"subject6\",\n",
" \"session_datetime\": \"2021-06-02 14:04:22\",\n",
" \"recording_id\": \"1\",\n",
"}\n",
"model.VideoRecording.insert1(\n",
" {**recording_key, \"device\": \"Camera1\"}, skip_duplicates=True\n",
")"
]
},
{
Expand All @@ -1810,12 +1821,14 @@
"metadata": {},
"outputs": [],
"source": [
"video_files = [\"./example_data/inbox/from_top_tracking-DataJoint-2023-10-11/videos/train1.mp4\"]\n",
"video_files = [\n",
" \"./example_data/inbox/from_top_tracking-DataJoint-2023-10-11/videos/train1.mp4\"\n",
"]\n",
"\n",
"model.VideoRecording.File.insert({\n",
" **recording_key, \n",
" 'file_id': v_idx, \n",
" 'file_path': Path(f)} for v_idx, f in enumerate(video_files))"
"model.VideoRecording.File.insert(\n",
" {**recording_key, \"file_id\": v_idx, \"file_path\": Path(f)}\n",
" for v_idx, f in enumerate(video_files)\n",
")"
]
},
{
Expand Down Expand Up @@ -2054,7 +2067,7 @@
"metadata": {},
"outputs": [],
"source": [
"task_key = {**recording_key, 'model_name': 'from_top_tracking_model_test'}"
"task_key = {**recording_key, \"model_name\": \"from_top_tracking_model_test\"}"
]
},
{
Expand All @@ -2071,10 +2084,12 @@
"outputs": [],
"source": [
"model.PoseEstimationTask.insert1(\n",
" {**task_key,\n",
" 'task_mode': 'load',\n",
" 'pose_estimation_output_dir': './example_data/outbox/from_top_tracking-DataJoint-2023-10-11/videos/device_1_recording_1_model_from_top_tracking_100000_maxiters'\n",
" })"
" {\n",
" **task_key,\n",
" \"task_mode\": \"load\",\n",
" \"pose_estimation_output_dir\": \"./example_data/outbox/from_top_tracking-DataJoint-2023-10-11/videos/device_1_recording_1_model_from_top_tracking_100000_maxiters\",\n",
" }\n",
")"
]
},
{
Expand Down Expand Up @@ -2471,7 +2486,11 @@
"metadata": {},
"outputs": [],
"source": [
"df = (model.PoseEstimation.BodyPartPosition & task_key).fetch(format='frame').reset_index()"
"df = (\n",
" (model.PoseEstimation.BodyPartPosition & task_key)\n",
" .fetch(format=\"frame\")\n",
" .reset_index()\n",
")"
]
},
{
Expand Down Expand Up @@ -2836,7 +2855,7 @@
}
],
"source": [
"df = df.explode(['frame_index', 'x_pos', 'y_pos', 'likelihood']).reset_index()\n",
"df = df.explode([\"frame_index\", \"x_pos\", \"y_pos\", \"likelihood\"]).reset_index()\n",
"df"
]
},
Expand Down Expand Up @@ -2871,8 +2890,8 @@
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"head_data = df[df['body_part'] == 'head']\n",
"tail_data = df[df['body_part'] == 'tailbase']"
"head_data = df[df[\"body_part\"] == \"head\"]\n",
"tail_data = df[df[\"body_part\"] == \"tailbase\"]"
]
},
{
Expand All @@ -2892,18 +2911,18 @@
}
],
"source": [
"fig, axs = plt.subplots(2,1, figsize=(12, 4))\n",
"fig, axs = plt.subplots(2, 1, figsize=(12, 4))\n",
"\n",
"axs[0].set_title('x position - Head pose estimation')\n",
"axs[0].plot(head_data['x_pos'], label='x_pos')\n",
"axs[0].set_xlabel('time (frames)')\n",
"axs[0].set_ylabel('pos (pixels)')\n",
"axs[0].set_title(\"x position - Head pose estimation\")\n",
"axs[0].plot(head_data[\"x_pos\"], label=\"x_pos\")\n",
"axs[0].set_xlabel(\"time (frames)\")\n",
"axs[0].set_ylabel(\"pos (pixels)\")\n",
"axs[0].legend()\n",
"\n",
"axs[1].set_title('y position - Head pose estimation')\n",
"axs[1].plot(head_data['y_pos'], label='y_pos')\n",
"axs[1].set_xlabel('time (frames)')\n",
"axs[1].set_ylabel('pos (pixels)')\n",
"axs[1].set_title(\"y position - Head pose estimation\")\n",
"axs[1].plot(head_data[\"y_pos\"], label=\"y_pos\")\n",
"axs[1].set_xlabel(\"time (frames)\")\n",
"axs[1].set_ylabel(\"pos (pixels)\")\n",
"axs[1].legend()\n",
"\n",
"plt.tight_layout()\n",
Expand All @@ -2927,17 +2946,17 @@
}
],
"source": [
"fig, axs = plt.subplots(2,1, figsize=(12, 4))\n",
"axs[0].set_title('x position - Tailbase pose estimation')\n",
"axs[0].plot(head_data['x_pos'], label='x_pos',color='orange')\n",
"axs[0].set_xlabel('time (frames)')\n",
"axs[0].set_ylabel('pos (pixels)')\n",
"fig, axs = plt.subplots(2, 1, figsize=(12, 4))\n",
"axs[0].set_title(\"x position - Tailbase pose estimation\")\n",
"axs[0].plot(head_data[\"x_pos\"], label=\"x_pos\", color=\"orange\")\n",
"axs[0].set_xlabel(\"time (frames)\")\n",
"axs[0].set_ylabel(\"pos (pixels)\")\n",
"axs[0].legend()\n",
"\n",
"axs[1].set_title('y position - Tailbase pose estimation')\n",
"axs[1].plot(head_data['y_pos'], label='y_pos',color='orange')\n",
"axs[1].set_xlabel('time (frames)')\n",
"axs[1].set_ylabel('pos (pixels)')\n",
"axs[1].set_title(\"y position - Tailbase pose estimation\")\n",
"axs[1].plot(head_data[\"y_pos\"], label=\"y_pos\", color=\"orange\")\n",
"axs[1].set_xlabel(\"time (frames)\")\n",
"axs[1].set_ylabel(\"pos (pixels)\")\n",
"axs[1].legend()\n",
"\n",
"plt.tight_layout()\n",
Expand Down Expand Up @@ -2968,18 +2987,18 @@
}
],
"source": [
"fig, axs = plt.subplots(2,1, figsize=(6,10))\n",
"fig, axs = plt.subplots(2, 1, figsize=(6, 10))\n",
"\n",
"axs[0].set_title('Head pose estimation')\n",
"axs[0].plot(head_data['x_pos'], head_data['y_pos'],label='head',color='blue')\n",
"axs[0].set_xlabel('x position (pixels)')\n",
"axs[0].set_ylabel('y position (pixels)')\n",
"axs[0].set_title(\"Head pose estimation\")\n",
"axs[0].plot(head_data[\"x_pos\"], head_data[\"y_pos\"], label=\"head\", color=\"blue\")\n",
"axs[0].set_xlabel(\"x position (pixels)\")\n",
"axs[0].set_ylabel(\"y position (pixels)\")\n",
"axs[0].legend()\n",
"\n",
"axs[1].set_title('Tailbase pose estimation')\n",
"axs[1].plot(tail_data['x_pos'], tail_data['y_pos'], label='tailbase',color='orange')\n",
"axs[1].set_xlabel('x position (pixels)')\n",
"axs[1].set_ylabel('y position (pixels)')\n",
"axs[1].set_title(\"Tailbase pose estimation\")\n",
"axs[1].plot(tail_data[\"x_pos\"], tail_data[\"y_pos\"], label=\"tailbase\", color=\"orange\")\n",
"axs[1].set_xlabel(\"x position (pixels)\")\n",
"axs[1].set_ylabel(\"y position (pixels)\")\n",
"axs[1].legend()\n",
"\n",
"plt.tight_layout()\n",
Expand Down

0 comments on commit ded35b3

Please sign in to comment.