-
Notifications
You must be signed in to change notification settings - Fork 3
/
samlsp.go
148 lines (132 loc) · 3.68 KB
/
samlsp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
// Package samlplugin provides helpers that can be used to protect web
// services using SAML.
//
// This package is based on https://github.com/42wim/crewjam-saml
package samlplugin
import (
"crypto/rsa"
"crypto/x509"
"encoding/xml"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"time"
"github.com/42wim/crewjam-saml"
"github.com/42wim/crewjam-saml/logger"
)
// defaultTokenMaxAge is the token lifetime New applies when
// Options.CookieMaxAge is left at its zero value.
const defaultTokenMaxAge = 60 * time.Minute
// Options represents the parameters for creating a new middleware with New.
type Options struct {
	// URL is the base URL of this service provider. New derives the SP
	// metadata, ACS and SLO URLs by appending "/saml/metadata", "/saml/acs"
	// and "/saml/slo" to its Path.
	URL url.URL
	// Key is the service provider's private key, passed to
	// saml.ServiceProvider.Key.
	Key *rsa.PrivateKey
	// Logger is passed to the service provider and used to report metadata
	// fetch retries. If nil, logger.DefaultLogger is used.
	Logger logger.Interface
	// Certificate is the service provider's certificate, passed to
	// saml.ServiceProvider.Certificate.
	Certificate *x509.Certificate
	// AllowIDPInitiated is copied to SAMLPlugin.AllowIDPInitiated.
	// NOTE(review): presumably permits IdP-initiated logins; confirm in the
	// code that reads it.
	AllowIDPInitiated bool
	// IDPMetadata is the identity provider's metadata. It is used as-is when
	// IDPMetadataURL is nil and is replaced by the fetched metadata otherwise.
	IDPMetadata *saml.EntityDescriptor
	// IDPMetadataURL, if non-nil, is where New downloads the IDP metadata
	// from, retrying on failure.
	IDPMetadataURL *url.URL
	// HTTPClient is used to fetch IDPMetadataURL. If nil,
	// http.DefaultClient is used.
	HTTPClient *http.Client
	// CookieMaxAge becomes SAMLPlugin.TokenMaxAge. Zero means
	// defaultTokenMaxAge.
	CookieMaxAge time.Duration
	// CookieSecure is passed to the cookie store's Secure field.
	CookieSecure bool
	// ForceAuthn is referenced by pointer from
	// saml.ServiceProvider.ForceAuthn.
	ForceAuthn bool
	// CookieName is passed to the cookie store's Name field.
	CookieName string
	// EnableSessions is copied to SAMLPlugin.EnableSessions.
	EnableSessions bool
	// DbURI is passed to NewDB when the plugin is created.
	// NOTE(review): the accepted URI format is defined by NewDB; confirm there.
	DbURI string
}
// New creates a new SAMLPlugin configured from opts.
//
// The service provider's metadata, ACS and SLO URLs are derived from opts.URL
// by appending "/saml/metadata", "/saml/acs" and "/saml/slo" to its path. A nil
// opts.Logger falls back to logger.DefaultLogger, and a zero opts.CookieMaxAge
// falls back to defaultTokenMaxAge. A single ClientCookies value backs both
// ClientState and ClientToken.
//
// If opts.IDPMetadataURL is nil, opts.IDPMetadata is used unchanged and New
// returns a nil error. Otherwise the IDP metadata is downloaded from that URL
// with retries (fetchIDPMetadataWithRetry) and parsed
// (decodeIDPEntityDescriptor). The parsed entity replaces opts.IDPMetadata.
// Any download or parse error is returned together with a nil plugin.
//
// NOTE(review): when opts.HTTPClient is nil, http.DefaultClient is used. It has
// no timeout, so an unresponsive IDP can block New for a long time.
func New(opts Options) (*SAMLPlugin, error) {
	metadataURL := opts.URL
	metadataURL.Path = metadataURL.Path + "/saml/metadata"
	acsURL := opts.URL
	acsURL.Path = acsURL.Path + "/saml/acs"
	sloURL := opts.URL
	sloURL.Path = sloURL.Path + "/saml/slo"

	logr := opts.Logger
	if logr == nil {
		logr = logger.DefaultLogger
	}

	tokenMaxAge := opts.CookieMaxAge
	if opts.CookieMaxAge == 0 {
		tokenMaxAge = defaultTokenMaxAge
	}

	s := &SAMLPlugin{
		ServiceProvider: saml.ServiceProvider{
			Key:         opts.Key,
			Logger:      logr,
			Certificate: opts.Certificate,
			MetadataURL: metadataURL,
			AcsURL:      acsURL,
			SloURL:      sloURL,
			IDPMetadata: opts.IDPMetadata,
			// opts is New's own copy, so this pointer does not alias the
			// caller's Options value.
			ForceAuthn: &opts.ForceAuthn,
		},
		AllowIDPInitiated: opts.AllowIDPInitiated,
		TokenMaxAge:       tokenMaxAge,
		EnableSessions:    opts.EnableSessions,
		Db:                NewDB(opts.DbURI),
	}

	cookieStore := ClientCookies{
		ServiceProvider: &s.ServiceProvider,
		Name:            opts.CookieName,
		//Domain: opts.URL.Host,
		Secure: opts.CookieSecure,
	}
	s.ClientState = &cookieStore
	s.ClientToken = &cookieStore
	s.cacheMap = make(map[string]Session)

	// Fetch the IDP metadata only when a URL was supplied.
	if opts.IDPMetadataURL == nil {
		return s, nil
	}

	c := opts.HTTPClient
	if c == nil {
		c = http.DefaultClient
	}

	data, err := fetchIDPMetadataWithRetry(c, opts.IDPMetadataURL, logr)
	if err != nil {
		return nil, err
	}
	entity, err := decodeIDPEntityDescriptor(data)
	if err != nil {
		return nil, err
	}
	s.ServiceProvider.IDPMetadata = entity
	return s, nil
}

// fetchIDPMetadataWithRetry GETs metadataURL using c and returns the response
// body.
//
// An attempt fails on a transport error, a non-200 status, or a body read
// error. Each failure except the last is logged through logr and retried after
// a fixed delay. After 12 failed attempts, the last error is returned. The
// response body is closed on every path, including non-200 responses.
func fetchIDPMetadataWithRetry(c *http.Client, metadataURL *url.URL, logr logger.Interface) ([]byte, error) {
	const (
		maxAttempts = 12
		retryDelay  = 5 * time.Second
	)

	req, err := http.NewRequest(http.MethodGet, metadataURL.String(), nil)
	if err != nil {
		return nil, err
	}
	// Some providers (like OneLogin) do not work properly unless the User-Agent header is specified.
	// Setting the user agent prevents the 403 Forbidden errors.
	req.Header.Set("User-Agent", "Golang; github.com/42wim/crewjam-saml")

	// fetchOnce performs a single attempt. It is a closure so the deferred
	// Close runs at the end of each attempt, not when the loop finishes.
	fetchOnce := func() ([]byte, error) {
		resp, err := c.Do(req)
		if err != nil {
			return nil, err
		}
		defer resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			// resp.Status already includes the numeric code, e.g. "404 Not Found".
			return nil, fmt.Errorf("unexpected HTTP status %s", resp.Status)
		}
		return ioutil.ReadAll(resp.Body)
	}

	for attempt := 1; ; attempt++ {
		data, err := fetchOnce()
		if err == nil {
			return data, nil
		}
		if attempt >= maxAttempts {
			return nil, err
		}
		logr.Printf("ERROR: %s: %s (will retry)", metadataURL, err)
		time.Sleep(retryDelay)
	}
}

// decodeIDPEntityDescriptor parses IDP metadata XML into an EntityDescriptor.
//
// If the document's root is an <EntitiesDescriptor>, the entities are searched
// for one that has at least one IDPSSODescriptor. When several match, the last
// one is returned, which matches the original inline behavior. An error is
// returned if the XML cannot be parsed or no entity qualifies.
func decodeIDPEntityDescriptor(data []byte) (*saml.EntityDescriptor, error) {
	entity := &saml.EntityDescriptor{}
	err := xml.Unmarshal(data, entity)
	if err == nil {
		return entity, nil
	}
	// encoding/xml has no typed error for a root-element mismatch, so the
	// message text is the only way to detect it.
	if err.Error() != "expected element type <EntityDescriptor> but have <EntitiesDescriptor>" {
		return nil, err
	}

	entities := &saml.EntitiesDescriptor{}
	if err := xml.Unmarshal(data, entities); err != nil {
		return nil, err
	}

	var found *saml.EntityDescriptor
	for i := range entities.EntityDescriptors {
		if len(entities.EntityDescriptors[i].IDPSSODescriptors) > 0 {
			found = &entities.EntityDescriptors[i]
		}
	}
	if found == nil {
		return nil, fmt.Errorf("no entity found with IDPSSODescriptor")
	}
	return found, nil
}