This repository has been archived by the owner on Aug 24, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathanalyse_training_validation_loss.Rmd
98 lines (83 loc) · 2.36 KB
/
analyse_training_validation_loss.Rmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
---
title: "Training/validation loss"
output: html_notebook
---
```{r packages, include=FALSE, echo=FALSE}
if (!require(pacman)) install.packages("pacman")
pacman::p_load(gridExtra, here, tidyverse)
```
```{r define_doccano_colours}
# Hex colour per raw annotation label code.
# NOTE(review): presumably mirrors the label colours configured in doccano;
# confirm against the annotation project before changing any value.
doccano_colour_scheme_technical <- c(
  "ATTENTION" = "#f44e3b",
  "EVENT" = "#0062b1",
  "LOC" = "#7b64ff",
  "MISC" = "#a79c4f",
  "ORG" = "#16a5a5",
  "PER" = "#fcdc00",
  "Personal Bookmark" = "#fda1ff",
  "POSTCORR" = "#9f0500",
  "TIME" = "#73d8ff",
  "?" = "#cccccc"
)

# Same hex colours as the EVENT/LOC/MISC/ORG/PER/TIME entries above, keyed by
# human-readable category names instead of label codes.
doccano_colour_scheme_named <- c(
  "Event" = "#0062b1",
  "Location" = "#7b64ff",
  "Miscellaneous" = "#a79c4f",
  "Organisation" = "#16a5a5",
  "Person" = "#fcdc00",
  "Time" = "#73d8ff"
)
```
```{r, warning=FALSE}
# Plot training vs. validation loss for every sheet of the workbook: one figure
# per sheet (line plot on top, table of the values underneath), all sharing the
# same y-axis range, each saved to plots/loss_<sheet>.png.
data_file <- here::here("training-losses.xlsx")
labels <- data_file %>% readxl::excel_sheets()

# Read every sheet exactly once; the data is reused for both the shared y-axis
# limit and the per-sheet figures.
loss_data <- lapply(labels, function(label) {
  data_file %>%
    readxl::read_excel(sheet = label, col_types = "numeric") %>%
    select(Epoch, `Training Loss`, `Validation Loss`)
})
names(loss_data) <- labels

# find y max
# Always computed from the data just read. The previous `exists("y_max")`
# check silently reused a stale value left in the session by an earlier run.
all_losses <- unlist(
  lapply(loss_data, function(losses) select(losses, -Epoch)),
  use.names = FALSE
)
# Ignore missing cells so a single blank value does not discard the shared
# limit. With nothing to measure, NA lets ggplot2 choose the upper limit.
all_losses <- all_losses[!is.na(all_losses)]
y_max <- if (length(all_losses) > 0) max(all_losses) else NA_real_

for (label in labels) {
  losses <- loss_data[[label]]

  line_plot <-
    losses %>%
    ggplot(aes(x = Epoch)) +
    # `linewidth` replaces the deprecated `size` aesthetic for lines.
    geom_line(
      aes(y = `Training Loss`, color = "Training Loss"),
      linewidth = 1
    ) +
    geom_line(
      aes(y = `Validation Loss`, color = "Validation Loss"),
      linewidth = 1
    ) +
    scale_color_manual(
      values = c("Training Loss" = "dodgerblue", "Validation Loss" = "tomato")
    ) +
    labs(
      title = paste("Training and validation loss for", label, "over epochs"),
      x = "Epoch",
      y = "Loss",
      color = "Legend"
    ) +
    ylim(0, y_max) +
    theme_bw()

  # Fixed five-decimal strings so the table columns line up.
  table_plot <-
    losses %>%
    mutate(
      across(
        c(`Training Loss`, `Validation Loss`),
        function(x) format(round(x, 5), nsmall = 5)
      )
    ) %>%
    tableGrob(rows = NULL)

  # grid.arrange() draws the figure itself and returns the gtable. No extra
  # print() is needed: printing the gtable only emits its text layout.
  combined_plot <- grid.arrange(
    line_plot, table_plot,
    ncol = 1, heights = c(2, 1)
  )

  ggsave(
    filename = here::here("plots", paste0("loss_", label, ".png")),
    plot = combined_plot,
    create.dir = TRUE
  )
}
```