Skip to content

Commit

Permalink
fixes for the following issues:
Browse files Browse the repository at this point in the history
- #41
- #40
  • Loading branch information
fabriziosalmi committed Jan 29, 2025
1 parent c99e875 commit 9223d33
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 94 deletions.
6 changes: 6 additions & 0 deletions blacklist.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ func extractIP(remoteAddr string, logger *zap.Logger) string {
return host
}

// LoadIPBlacklistFromFile loads IP addresses from a file into the provided map.
// LoadIPBlacklistFromFile loads IP addresses from a file into the provided map.
func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[string]struct{}) error {
bl.logger.Debug("Loading IP blacklist", zap.String("path", path))
Expand All @@ -141,6 +142,7 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[
scanner := bufio.NewScanner(file)
validEntries := 0
totalLines := 0
invalidEntries := 0

for scanner.Scan() {
totalLines++
Expand All @@ -156,6 +158,9 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[
zap.Int("line", totalLines),
zap.String("entry", line),
)
invalidEntries++
// If you want the entire load to fail if any single IP entry is invalid, uncomment the line below
// return fmt.Errorf("failed to add IP entry %s : %w", line, err)
} else {
validEntries++
}
Expand All @@ -169,6 +174,7 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[
bl.logger.Info("IP blacklist loaded",
zap.String("path", path),
zap.Int("valid_entries", validEntries),
zap.Int("invalid_entries", invalidEntries),
zap.Int("total_lines", totalLines),
)
return nil
Expand Down
97 changes: 38 additions & 59 deletions caddywaf.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"os"
"strings"
Expand All @@ -29,6 +28,7 @@ var (
_ caddy.Provisioner = (*Middleware)(nil)
_ caddyhttp.MiddlewareHandler = (*Middleware)(nil)
_ caddyfile.Unmarshaler = (*Middleware)(nil)
_ caddy.Validator = (*Middleware)(nil)
)

// ==================== Initialization and Setup ====================
Expand Down Expand Up @@ -209,19 +209,18 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
}

// Load IP blacklist
m.ipBlacklist = NewCIDRTrie()
m.logger.Debug("ipBlacklist initialized in Provision", zap.Bool("isNil", m.ipBlacklist == nil))
if m.IPBlacklistFile != "" {
err = m.loadIPBlacklistIntoMap(m.IPBlacklistFile, m.ipBlacklist)
m.ipBlacklist = NewCIDRTrie()
err = m.loadIPBlacklist(m.IPBlacklistFile, m.ipBlacklist)
if err != nil {
return fmt.Errorf("failed to load IP blacklist: %w", err)
}
}

// Load DNS blacklist
m.dnsBlacklist = make(map[string]struct{}) // Changed to map[string]struct{}
if m.DNSBlacklistFile != "" {
err = m.blacklistLoader.LoadDNSBlacklistFromFile(m.DNSBlacklistFile, m.dnsBlacklist)
m.dnsBlacklist = make(map[string]struct{})
err = m.loadDNSBlacklist(m.DNSBlacklistFile, m.dnsBlacklist)
if err != nil {
return fmt.Errorf("failed to load DNS blacklist: %w", err)
}
Expand Down Expand Up @@ -414,25 +413,21 @@ func (m *Middleware) ReloadConfig() error {
defer m.mu.Unlock()

m.logger.Info("Reloading WAF configuration")

newIPBlacklist := NewCIDRTrie()
if m.IPBlacklistFile != "" {
if err := m.loadIPBlacklistIntoMap(m.IPBlacklistFile, newIPBlacklist); err != nil {
newIPBlacklist := NewCIDRTrie()
if err := m.loadIPBlacklist(m.IPBlacklistFile, newIPBlacklist); err != nil {
m.logger.Error("Failed to reload IP blacklist", zap.String("file", m.IPBlacklistFile), zap.Error(err))
return fmt.Errorf("failed to reload IP blacklist: %v", err)
}
} else {
m.logger.Debug("No IP blacklist file specified, skipping reload")
m.ipBlacklist = newIPBlacklist
}

newDNSBlacklist := make(map[string]struct{})
if m.DNSBlacklistFile != "" {
if err := m.loadDNSBlacklistIntoMap(m.DNSBlacklistFile, newDNSBlacklist); err != nil {
newDNSBlacklist := make(map[string]struct{})
if err := m.loadDNSBlacklist(m.DNSBlacklistFile, newDNSBlacklist); err != nil {
m.logger.Error("Failed to reload DNS blacklist", zap.String("file", m.DNSBlacklistFile), zap.Error(err))
return fmt.Errorf("failed to reload DNS blacklist: %v", err)
}
} else {
m.logger.Debug("No DNS blacklist file specified, skipping reload")
m.dnsBlacklist = newDNSBlacklist
}

// Call the external loadRules function
Expand All @@ -441,62 +436,38 @@ func (m *Middleware) ReloadConfig() error {
return fmt.Errorf("failed to reload rules: %v", err)
}

m.ipBlacklist = newIPBlacklist
m.dnsBlacklist = newDNSBlacklist

m.logger.Info("WAF configuration reloaded successfully")

return nil
}

func (m *Middleware) loadIPBlacklistIntoMap(path string, blacklistMap *CIDRTrie) error {
content, err := os.ReadFile(path)
if err != nil {
return fmt.Errorf("failed to read IP blacklist file: %v", err)
func (m *Middleware) loadIPBlacklist(path string, blacklistMap *CIDRTrie) error {
if _, err := os.Stat(path); os.IsNotExist(err) {
m.logger.Warn("Skipping IP blacklist load, file does not exist", zap.String("file", path))
return nil
}

lines := strings.Split(string(content), "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
}

if !strings.Contains(line, "/") {
// Handle single IP addresses
ip := net.ParseIP(line)
if ip == nil {
m.logger.Warn("Skipping invalid IP address format in blacklist", zap.String("address", line))
continue
}

if ip.To4() != nil {
line = line + "/32"
} else {
line = line + "/128"
}
}
blacklist := make(map[string]struct{})
err := m.blacklistLoader.LoadIPBlacklistFromFile(path, blacklist)
if err != nil {
return fmt.Errorf("failed to load IP blacklist: %w", err)
}

if err := blacklistMap.Insert(line); err != nil {
m.logger.Warn("Failed to insert CIDR into trie", zap.String("cidr", line), zap.Error(err))
}
// Convert the map to CIDRTrie
for ip := range blacklist {
blacklistMap.Insert(ip)
}
return nil
}

func (m *Middleware) loadDNSBlacklistIntoMap(path string, blacklistMap map[string]struct{}) error {
content, err := os.ReadFile(path)
if err != nil {
return fmt.Errorf("failed to read DNS blacklist file: %v", err)
func (m *Middleware) loadDNSBlacklist(path string, blacklistMap map[string]struct{}) error {
if _, err := os.Stat(path); os.IsNotExist(err) {
m.logger.Warn("Skipping DNS blacklist load, file does not exist", zap.String("file", path))
return nil
}

lines := strings.Split(string(content), "\n")
for _, line := range lines {
line = strings.ToLower(strings.TrimSpace(line))
if line == "" || strings.HasPrefix(line, "#") {
continue
}
blacklistMap[line] = struct{}{} // Changed to struct{}{}
err := m.blacklistLoader.LoadDNSBlacklistFromFile(path, blacklistMap)
if err != nil {
return fmt.Errorf("failed to load DNS blacklist: %w", err)
}
return nil
}
Expand Down Expand Up @@ -570,3 +541,11 @@ func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
return m.configLoader.UnmarshalCaddyfile(d, m)
}

// Validate implements caddy.Validator.
func (m *Middleware) Validate() error {
if m.logLevel == 0 {
m.logLevel = zapcore.InfoLevel // Default log level
}
return nil
}
50 changes: 25 additions & 25 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ import (
"go.uber.org/zap/zapcore"
)

type ContextKeyLogId string
type ContextKeyRule string

// ServeHTTP implements caddyhttp.Handler.
func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
logID := uuid.New().String()
Expand Down Expand Up @@ -41,7 +44,7 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
err := next.ServeHTTP(recorder, r)

// Phase 3: Response Header analysis
if m.isResponseHeaderPhaseBlocked(recorder, r, 3, state) {
if m.isPhaseBlocked(recorder, r, 3, state) {
return nil // Request blocked in Phase 3, short-circuit
}

Expand Down Expand Up @@ -83,17 +86,6 @@ func (m *Middleware) isPhaseBlocked(w http.ResponseWriter, r *http.Request, phas
return false
}

// isResponseHeaderPhaseBlocked encapsulates the response header phase handling and blocking check logic.
func (m *Middleware) isResponseHeaderPhaseBlocked(recorder *responseRecorder, r *http.Request, phase int, state *WAFState) bool {
m.handlePhase(recorder, r, phase, state)
if state.Blocked {
m.incrementBlockedRequestsMetric()
recorder.WriteHeader(state.StatusCode)
return true
}
return false
}

// logRequestStart logs the start of WAF evaluation.
func (m *Middleware) logRequestStart(r *http.Request, logID string) {
m.logger.Info("WAF request evaluation started",
Expand Down Expand Up @@ -126,15 +118,20 @@ func (m *Middleware) initializeWAFState() *WAFState {
func (m *Middleware) handleResponseBodyPhase(recorder *responseRecorder, r *http.Request, state *WAFState) {
// No need to check if recorder.body is nil here, it's always initialized in NewResponseRecorder
body := recorder.BodyString()
m.logger.Debug("Response body captured for Phase 4 analysis", zap.String("log_id", r.Context().Value(ContextKeyLogId("logID")).(string)))
logID, ok := r.Context().Value(ContextKeyLogId("logID")).(string)

if !ok {
m.logger.Error("Log ID not found in context")
return
}
m.logger.Debug("Response body captured for Phase 4 analysis", zap.String("log_id", logID))

for _, rule := range m.Rules[4] {
if rule.regex.MatchString(body) {
m.processRuleMatch(recorder, r, &rule, body, state)
if state.Blocked {
m.incrementBlockedRequestsMetric()
if m.processRuleMatch(recorder, r, &rule, body, state) {
return
}

}
}
}
Expand Down Expand Up @@ -191,9 +188,16 @@ func (m *Middleware) copyResponse(w http.ResponseWriter, recorder *responseRecor
}
w.WriteHeader(recorder.StatusCode())

logID, ok := r.Context().Value(ContextKeyLogId("logID")).(string)

if !ok {
m.logger.Error("Log ID not found in context")
return
}

_, err := w.Write(recorder.body.Bytes()) // Copy body from recorder to original writer
if err != nil {
m.logger.Error("Failed to write recorded response body to client", zap.Error(err), zap.String("log_id", r.Context().Value("logID").(string)))
m.logger.Error("Failed to write recorded response body to client", zap.Error(err), zap.String("log_id", logID))
}
}

Expand Down Expand Up @@ -221,7 +225,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr,
zap.String("message", "Request blocked by country"),
)
m.logger.Debug("Country blocking phase completed - blocked by country")
return
}
m.logger.Debug("Country blocking phase completed - not blocked")
Expand All @@ -235,7 +238,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
m.blockRequest(w, r, state, http.StatusTooManyRequests, "rate_limit", "rate_limit_rule", r.RemoteAddr,
zap.String("message", "Request blocked by rate limit"),
)
m.logger.Debug("Rate limiting phase completed - blocked by rate limit")
return
}
m.logger.Debug("Rate limiting phase completed - not blocked")
Expand All @@ -254,7 +256,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule", firstIP,
zap.String("message", "Request blocked by IP blacklist"),
)
m.logger.Debug("IP blacklist phase completed - blocked")
return
}
} else {
Expand All @@ -269,7 +270,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule", r.RemoteAddr,
zap.String("message", "Request blocked by IP blacklist"),
)
m.logger.Debug("IP blacklist phase completed - blocked")
return
}
}
Expand All @@ -281,7 +281,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
zap.String("message", "Request blocked by DNS blacklist"),
zap.String("host", r.Host),
)
m.logger.Debug("DNS blacklist phase completed - blocked")
return
}

Expand All @@ -292,6 +291,7 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
}

m.logger.Debug("Starting rule evaluation for phase", zap.Int("phase", phase), zap.Int("rule_count", len(rules)))

for _, rule := range rules {
m.logger.Debug("Processing rule", zap.String("rule_id", string(rule.ID)), zap.Int("target_count", len(rule.Targets)))

Expand Down Expand Up @@ -338,16 +338,16 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
)
if phase == 3 || phase == 4 {
if recorder, ok := w.(*responseRecorder); ok {
if !m.processRuleMatch(recorder, r, &rule, value, state) {
if m.processRuleMatch(recorder, r, &rule, value, state) {
return // Stop processing if the rule match indicates blocking
}
} else {
if !m.processRuleMatch(w, r, &rule, value, state) {
if m.processRuleMatch(w, r, &rule, value, state) {
return // Stop processing if the rule match indicates blocking
}
}
} else {
if !m.processRuleMatch(w, r, &rule, value, state) {
if m.processRuleMatch(w, r, &rule, value, state) {
return // Stop processing if the rule match indicates blocking
}
}
Expand Down
2 changes: 1 addition & 1 deletion logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ var sensitiveKeys = []string{
"routing", // Routing number
"mfa", // Multi-factor authentication code
"otp", // One-time password
"code", // Generic code
//"code", // Generic code <------ REMOVED THIS
}

var sensitiveKeysMutex sync.RWMutex // Add mutex for thread safety when modifying
Expand Down
3 changes: 0 additions & 3 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ type RequestValueExtractor struct {
redactSensitiveData bool // Add this field
}

// Define a custom type for context keys
type ContextKeyLogId string

// Extraction Target Constants - Improved Readability and Maintainability
const (
TargetMethod = "METHOD"
Expand Down
Loading

0 comments on commit 9223d33

Please sign in to comment.