forked from julianje/ImageInference
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathVisualizePosterior_Short.R
128 lines (108 loc) · 5.73 KB
/
VisualizePosterior_Short.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
library(tidyverse)
library(gganimate)
library(RColorBrewer)
library(png)
library(grid)
# Background image -----------------------------------------
# Read the PNG and wrap it in a raster grob (`g`) for use by later
# plotting code.
mypng <- readPNG("~/Desktop/introduction_0.png")
g <- rasterGrob(mypng, interpolate = TRUE)

# Load everything ------------------------------------------
# Only plot trajectories with probability higher than this:
Threshhold <- 0.001
TrialName <- "UN_DX_0"
directory <- "~/Documents/Projects/Models/ImageInference/ImageInference/ModelPredictions/"

# NOTE(review): this changes the working directory for the rest of the
# session; any later relative path resolves against `directory`.
setwd(directory)

# Per-trial input files. `directory` already ends in "/", so the pieces
# are concatenated without a separator.
statesfile <- paste0(directory, TrialName, "_States_Posterior.csv")
goalfile <- paste0(directory, TrialName, "_Goal_Posterior.csv")
timefile <- paste0(directory, TrialName, "_Time_Estimates.csv")

# `col_types = cols()` lets readr guess column types without printing the
# column specification.
posterior <- read_csv(statesfile, col_types = cols())
goals <- read_csv(goalfile, col_types = cols())
time <- read_csv(timefile, col_types = cols())

# Map dimensions are read from the first row of the states posterior.
mapheight <- posterior$MapHeight[1]
mapwidth <- posterior$MapWidth[1]
# Visualize goal inference -------------------------------
# Bar chart of `Probability` for each `Goal` in the goal-posterior table,
# titled with the trial name.
ggplot(goals, aes(x = Goal, y = Probability)) +
  geom_bar(stat = "identity") +
  theme_classic() +
  # Upper limit is slightly above 1 -- presumably so a bar of exactly 1 is
  # not dropped by the scale limits.
  scale_y_continuous(limits = c(0, 1.01)) +
  theme(aspect.ratio = 1) +
  ggtitle(TrialName)
# Visualize time -----------------------------------
# Sum `Probability` within each `Time` value, then draw the totals as a
# bar chart on a fixed 0-100 time axis.
time %>%
  group_by(Time) %>%
  summarise(Probability = sum(Probability)) %>%
  ggplot(aes(x = Time, y = Probability)) +
  geom_bar(stat = "identity") +
  theme_classic() +
  scale_y_continuous(limits = c(0, 1.01), breaks = c(0, 0.5, 1)) +
  scale_x_continuous(limits = c(0, 100)) +
  theme(aspect.ratio = 1) +
  ggtitle(paste0("Time estimate: ", TrialName))
# Visualize inferred path --------------------------------
# Grid coordinates of the scene cell, derived from the `Scene` index in the
# first row of the posterior.
# NOTE(review): x is computed with `%% width + 1` while y uses
# `ceiling(idx / width)`; these agree only under specific indexing
# conventions -- confirm against the model's state numbering.
scene_x <- posterior$Scene[1] %% mapwidth + 1
scene_y <- ceiling(posterior$Scene[1] / mapwidth)

# Reshape the posterior to long format: one row per (trajectory Id, Time)
# with the grid position (x, y) of the state visited at that step.
States <- posterior %>%
  filter(Probability > 0) %>%
  # Row number becomes the trajectory identifier (as a character column
  # placed first).
  rownames_to_column("Id") %>%
  # Gather every column from `Obs0` through the last one; the `+ 1`
  # compensates for the `Id` column that was just added in front.
  gather(Step, State, Obs0:(ncol(posterior) + 1)) %>%
  # Split the column name after its third character ("Obs" | step number).
  separate(Step, into = c("Discard", "Time"), sep = 3) %>%
  dplyr::select(-Discard) %>%
  mutate(Time = as.numeric(Time)) %>%
  # `Id` is still character here, so this ordering is lexicographic on Id.
  arrange(Id, Time) %>%
  mutate(
    y = ceiling(State / MapWidth),
    x = State %% MapWidth + 1,
    Id = as.numeric(Id)
  )

# Reversed 11-class "Spectral" palette, interpolated on demand.
myPalette <- colorRampPalette(rev(brewer.pal(11, "Spectral")))
# Colour scale for time steps.
# NOTE(review): limits are hard-coded to 0..15; values outside that range
# fall outside the scale limits -- confirm this covers the longest path.
sc <- scale_colour_gradientn(colours = myPalette(100), limits = c(0, 15))
# color path based on time:
# Draw every trajectory whose probability is at least `Threshhold` as a
# jittered path coloured by time step, then overlay the scene marker, four
# corner points, three coloured rectangles and the background image grob.
States %>%
  filter(Probability >= Threshhold) %>%
  arrange(Probability, Time) %>%
  ggplot() +
  geom_path(
    aes(x = x, y = y, group = Id, color = Time, alpha = Probability + 0.75),
    position = position_jitter(height = 0.25, width = 0.25)
  ) +
  geom_point(aes(x = x, y = y, color = Time, alpha = Probability + 0.75)) +
  scale_x_continuous("", limits = c(0, mapwidth + 1)) +
  # Reversed y axis so row 1 is drawn at the top.
  scale_y_reverse("", limits = c(mapheight + 1, 1)) +
  theme_void() +
  theme(
    legend.title = element_blank(),
    legend.position = "none",
    axis.title.x = element_blank(),
    axis.text.x = element_blank(),
    axis.ticks.x = element_blank(),
    axis.title.y = element_blank(),
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  ) +
  sc +
  coord_fixed() +
  # Scene marker. Drawn with annotate() so exactly one point is rendered;
  # mapping the scalar coordinates through aes() would inherit the full
  # trajectory data and draw one identical point per row.
  annotate(
    "point",
    x = scene_x, y = scene_y,
    size = 6, shape = 19, colour = "#000000"
  ) +
  # Four small corner markers at hard-coded grid coordinates.
  geom_point(
    data = data.frame(
      xv = c(2.5, 2.5, 9.5, 9.5),
      yv = c(2.5, 9.5, 2.5, 9.5)
    ),
    aes(xv, yv),
    size = 0.1
  ) +
  # Three unit squares at hard-coded positions, one fill colour each.
  geom_rect(
    data = data.frame(
      x1 = c(2.5, 2.5, 8.5),
      x2 = c(3.5, 3.5, 9.5),
      y1 = c(2.5, 8.5, 2.5),
      y2 = c(3.5, 9.5, 3.5),
      id = c("a", "b", "c")
    ),
    aes(xmin = x1, xmax = x2, ymin = y1, ymax = y2, fill = id),
    alpha = 1
  ) +
  scale_fill_manual(values = c("#FF8F66", "#8293FF", "#7AB532")) +
  annotation_custom(g, xmin = 2.5, xmax = 9.5, ymin = -Inf, ymax = 0.5)
#annotation_raster(mypng, ymin = 2.5,ymax= 9.5,xmin = 2.5,xmax = 9.5)
# color path based on probability
# Colour scale spanning the full range of trajectory probabilities in
# `States` (computed before the threshold filter below).
scb <- scale_colour_gradientn(
  colours = myPalette(100),
  limits = range(States$Probability)
)

# Same layout as the time-coloured plot, but paths are coloured by their
# probability and visited cells are drawn as plain gray points.
States %>%
  filter(Probability >= Threshhold) %>%
  arrange(Probability, Time) %>%
  ggplot() +
  geom_path(
    aes(
      x = x, y = y, group = Id,
      color = Probability, alpha = Probability + 0.75
    ),
    position = position_jitter(height = 0.2, width = 0.2)
  ) +
  geom_point(aes(x = x, y = y), colour = "gray") +
  scale_x_continuous("", limits = c(0, mapwidth + 1)) +
  # Reversed y axis so row 1 is drawn at the top.
  scale_y_reverse("", limits = c(mapheight + 1, 1)) +
  theme_void() +
  theme(
    legend.title = element_blank(),
    legend.position = "none",
    axis.title.x = element_blank(),
    axis.text.x = element_blank(),
    axis.ticks.x = element_blank(),
    axis.title.y = element_blank(),
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  ) +
  scb +
  coord_fixed() +
  # Scene marker. Drawn with annotate() so exactly one point is rendered;
  # mapping the scalar coordinates through aes() would inherit the full
  # trajectory data and draw one identical point per row.
  annotate(
    "point",
    x = scene_x, y = scene_y,
    size = 5, colour = "#22231E"
  ) +
  # Four small corner markers at hard-coded grid coordinates.
  geom_point(
    data = data.frame(
      xv = c(2.5, 2.5, 9.5, 9.5),
      yv = c(2.5, 9.5, 2.5, 9.5)
    ),
    aes(xv, yv),
    size = 0.1
  ) +
  # Three unit squares at hard-coded positions, one fill colour each.
  geom_rect(
    data = data.frame(
      x1 = c(2.5, 2.5, 8.5),
      x2 = c(3.5, 3.5, 9.5),
      y1 = c(2.5, 8.5, 2.5),
      y2 = c(3.5, 9.5, 3.5),
      id = c("a", "b", "c")
    ),
    aes(xmin = x1, xmax = x2, ymin = y1, ymax = y2, fill = id),
    alpha = 1
  ) +
  scale_fill_manual(values = c("#FF8F66", "#8293FF", "#7AB532"))
# Visualize starting point -------------------------------
# Total probability of each state occupied at Time 0, i.e. summed over all
# trajectories that begin in that state.
Entrances <- States %>%
  filter(Time == 0) %>%
  group_by(State) %>%
  summarise(Probability = sum(Probability)) %>%
  rename(Entrance = State)

# Bar chart of the starting-state totals; states are shown as discrete
# categories.
Entrances %>%
  ggplot(aes(factor(Entrance), Probability)) +
  geom_bar(stat = "identity") +
  theme_classic() +
  scale_x_discrete("Starting point") +
  scale_y_continuous("Probability", limits = c(0, 1.01)) +
  ggtitle("Inferred starting point")

# Save the table next to the other outputs (path is relative to the current
# working directory). paste0() is used so the file name has no separator
# between the trial name and the suffix; paste() would insert a space.
write.csv(
  Entrances,
  paste0(TrialName, "_Entrance.csv"),
  quote = FALSE,
  row.names = FALSE
)
# Visualize density of actions over time:
# Fill scale shared by all facets, fixed to the 0..1 probability range.
scc <- scale_fill_gradientn(colours = myPalette(100), limits = c(0, 1))

# For every time step, sum trajectory probability per grid cell and draw it
# as a labelled heat map, one facet per time step.
States %>%
  group_by(Time, y, x) %>%
  summarise(Probability = sum(Probability)) %>%
  ggplot(aes(x = x, y = y, fill = Probability)) +
  geom_tile() +
  facet_wrap(~Time) +
  scale_x_continuous("", limits = c(2, 9.51)) +
  scale_y_reverse("", limits = c(10.5, 2)) +
  scc +
  # Grid lines on the cell boundaries 2.5, 3.5, ..., 9.5.
  geom_hline(yintercept = (2:9) + 0.5) +
  geom_vline(xintercept = (2:9) + 0.5) +
  # Print each cell's probability, rounded to two decimals.
  geom_text(aes(x = x, y = y, label = round(Probability, 2))) +
  theme_classic()