Skip to content

Commit

Permalink
logging improved to adhere to Caddy standards.
Browse files Browse the repository at this point in the history
  • Loading branch information
fabriziosalmi committed Jan 15, 2025
1 parent afc3c1d commit 2c17527
Show file tree
Hide file tree
Showing 7 changed files with 169 additions and 38 deletions.
4 changes: 2 additions & 2 deletions Caddyfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
log {
output stdout
format console
level INFO
level DEBUG
}

handle {
Expand All @@ -34,7 +34,7 @@
# rule_file rules/wordpress.json
ip_blacklist_file ip_blacklist.txt
dns_blacklist_file dns_blacklist.txt
log_severity info
log_severity debug
log_json
log_path debug.json
# redact_sensitive_data
Expand Down
8 changes: 6 additions & 2 deletions blacklist.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (bl *BlacklistLoader) LoadDNSBlacklistFromFile(path string, dnsBlacklist ma

// isIPBlacklisted checks if the given IP address is in the blacklist.
func (m *Middleware) isIPBlacklisted(remoteAddr string) bool {
ipStr := extractIP(remoteAddr)
ipStr := extractIP(remoteAddr, m.logger) // Pass the logger here
if ipStr == "" {
return false
}
Expand Down Expand Up @@ -121,9 +121,13 @@ func (m *Middleware) isDNSBlacklisted(host string) bool {
}

// extractIP extracts the IP address from a remote address string.
func extractIP(remoteAddr string) string {
func extractIP(remoteAddr string, logger *zap.Logger) string {
host, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
logger.Debug("Failed to extract IP from remote address, using full address",
zap.String("remoteAddr", remoteAddr),
zap.Error(err),
)
return remoteAddr // Assume the input is already an IP address
}
return host
Expand Down
6 changes: 3 additions & 3 deletions caddywaf.go
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
blocked, err := m.isCountryInList(r.RemoteAddr, m.CountryBlock.CountryList, m.CountryBlock.geoIP)
if err != nil {
m.logRequest(zapcore.ErrorLevel, "Failed to check country block",
zap.String("ip", r.RemoteAddr),
r,
zap.Error(err),
)
m.blockRequest(w, r, state, http.StatusForbidden,
Expand All @@ -673,8 +673,8 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i

if phase == 1 && m.rateLimiter != nil {
m.logger.Debug("Starting rate limiting phase")
ip := extractIP(r.RemoteAddr)
path := r.URL.Path // Get the request path
ip := extractIP(r.RemoteAddr, m.logger) // Pass the logger here
path := r.URL.Path // Get the request path
if m.rateLimiter.isRateLimited(ip, path) {
m.blockRequest(w, r, state, http.StatusTooManyRequests,
zap.String("message", "Request blocked by rate limit"),
Expand Down
88 changes: 82 additions & 6 deletions logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package caddywaf

import (
"net/http"
"strings"
"time"

Expand All @@ -10,11 +11,16 @@ import (
"go.uber.org/zap/zapcore"
)

func (m *Middleware) logRequest(level zapcore.Level, msg string, fields ...zap.Field) {
func (m *Middleware) logRequest(level zapcore.Level, msg string, r *http.Request, fields ...zap.Field) {
if m.logger == nil {
return
}

// Debug: Print the incoming fields
m.logger.Debug("Incoming fields to logRequest",
zap.Any("fields", fields),
)

// Extract log ID or generate a new one
var logID string
var newFields []zap.Field
Expand All @@ -36,9 +42,32 @@ func (m *Middleware) logRequest(level zapcore.Level, msg string, fields ...zap.F
// Append log_id explicitly to newFields
newFields = append(newFields, zap.String("log_id", logID))

// Attach common request metadata
commonFields := m.getCommonLogFields(newFields)
newFields = append(newFields, commonFields...)
// Debug: Print the newFields before calling getCommonLogFields
m.logger.Debug("New fields before getCommonLogFields",
zap.Any("newFields", newFields),

Check failure

Code scanning / CodeQL

Clear-text logging of sensitive information High

Sensitive data returned by HTTP request headers
flows to a logging call.
Sensitive data returned by HTTP request headers
flows to a logging call.
)

// Attach common request metadata only if not already set
commonFields := m.getCommonLogFields(r, newFields)
for _, commonField := range commonFields {
// Check if the field is already set in newFields
fieldExists := false
for _, existingField := range newFields {
if existingField.Key == commonField.Key {
fieldExists = true
break
}
}
// Add the common field only if it doesn't already exist
if !fieldExists {
newFields = append(newFields, commonField)
}
}

// Debug: Print the newFields after calling getCommonLogFields
m.logger.Debug("New fields after getCommonLogFields",
zap.Any("newFields", newFields),

Check failure

Code scanning / CodeQL

Clear-text logging of sensitive information High

Sensitive data returned by HTTP request headers
flows to a logging call.
Sensitive data returned by HTTP request headers
flows to a logging call.
)

// Determine the log level if unset
if m.logLevel == 0 {
Expand Down Expand Up @@ -68,11 +97,21 @@ func (m *Middleware) logRequest(level zapcore.Level, msg string, fields ...zap.F
}
}

func (m *Middleware) getCommonLogFields(fields []zap.Field) []zap.Field {
func (m *Middleware) getCommonLogFields(r *http.Request, fields []zap.Field) []zap.Field {
// Debug: Print the incoming fields
m.logger.Debug("Incoming fields to getCommonLogFields",
zap.Any("fields", fields),

Check failure

Code scanning / CodeQL

Clear-text logging of sensitive information High

Sensitive data returned by HTTP request headers
flows to a logging call.
Sensitive data returned by HTTP request headers
flows to a logging call.
)

// Extract or assign default values for metadata fields
var sourceIP, userAgent, requestMethod, requestPath, queryParams string
var sourceIP string
var userAgent string
var requestMethod string
var requestPath string
var queryParams string
var statusCode int

// Extract values from the incoming fields
for _, field := range fields {
switch field.Key {
case "source_ip":
Expand All @@ -90,6 +129,33 @@ func (m *Middleware) getCommonLogFields(fields []zap.Field) []zap.Field {
}
}

// If values are not provided in the fields, extract them from the request
if sourceIP == "" && r != nil {
sourceIP = r.RemoteAddr
}
if userAgent == "" && r != nil {
userAgent = r.UserAgent()
}
if requestMethod == "" && r != nil {
requestMethod = r.Method
}
if requestPath == "" && r != nil {
requestPath = r.URL.Path
}
if queryParams == "" && r != nil {
queryParams = r.URL.RawQuery
}

// Debug: Print the extracted values
m.logger.Debug("Extracted values in getCommonLogFields",
zap.String("source_ip", sourceIP),
zap.String("user_agent", userAgent),
zap.String("request_method", requestMethod),
zap.String("request_path", requestPath),
zap.String("query_params", queryParams),
zap.Int("status_code", statusCode),
)

// Default values for missing fields
if sourceIP == "" {
sourceIP = "unknown"
Expand All @@ -104,6 +170,16 @@ func (m *Middleware) getCommonLogFields(fields []zap.Field) []zap.Field {
requestPath = "unknown"
}

// Debug: Print the final values after applying defaults
m.logger.Debug("Final values after applying defaults",
zap.String("source_ip", sourceIP),
zap.String("user_agent", userAgent),
zap.String("request_method", requestMethod),
zap.String("request_path", requestPath),
zap.String("query_params", queryParams),
zap.Int("status_code", statusCode),
)

// Redact query parameters if required
if m.RedactSensitiveData {
queryParams = m.redactQueryParams(queryParams)
Expand Down
59 changes: 49 additions & 10 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"github.com/google/uuid"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)

// RequestValueExtractor struct
Expand Down Expand Up @@ -158,6 +159,7 @@ func (rve *RequestValueExtractor) extractSingleValue(target string, r *http.Requ
return "", fmt.Errorf("header '%s' not found for target: %s", headerName, target)
}
unredactedValue = headerValue

// Dynamic Response Header Extraction (Phase 3)
case strings.HasPrefix(target, "RESPONSE_HEADERS:"):
if w == nil {
Expand Down Expand Up @@ -229,6 +231,7 @@ func (rve *RequestValueExtractor) extractSingleValue(target string, r *http.Requ
rve.logger.Debug("Failed to extract value from JSON path", zap.String("target", target), zap.String("path", jsonPath), zap.Error(err))
return "", fmt.Errorf("failed to extract from JSON path '%s': %w", jsonPath, err)
}

// New cases start here:
case target == "CONTENT_TYPE":
unredactedValue = r.Header.Get("Content-Type")
Expand Down Expand Up @@ -259,7 +262,7 @@ func (rve *RequestValueExtractor) extractSingleValue(target string, r *http.Requ
return "", fmt.Errorf("unknown extraction target: %s", target)
}

// Redact sensitive fields (unchanged)
// Redact sensitive fields before returning the value
value := unredactedValue
if rve.redactSensitiveData {
sensitiveTargets := []string{"password", "token", "apikey", "authorization", "secret"}
Expand All @@ -270,62 +273,98 @@ func (rve *RequestValueExtractor) extractSingleValue(target string, r *http.Requ
}
}
}

// Log the extracted value (redacted if necessary)
rve.logger.Debug("Extracted value",
zap.String("target", target),
zap.String("value", value), // Now logging the potentially redacted value
zap.String("value", value), // Log the potentially redacted value

Check failure

Code scanning / CodeQL

Clear-text logging of sensitive information High

Sensitive data returned by HTTP request headers
flows to a logging call.
Sensitive data returned by HTTP request headers
flows to a logging call.
)
return unredactedValue, nil // Return the unredacted value for rule matching

// Return the unredacted value for rule matching
return unredactedValue, nil
}

// Helper function for JSON path extraction.
func (rve *RequestValueExtractor) extractJSONPath(jsonStr string, jsonPath string) (string, error) {
// Validate input JSON string
if jsonStr == "" {
return "", fmt.Errorf("json string is empty")
}

// Validate JSON path
if jsonPath == "" {
return "", fmt.Errorf("json path is empty")
}

// Unmarshal JSON string into an interface{}
var jsonData interface{}
if err := json.Unmarshal([]byte(jsonStr), &jsonData); err != nil {
return "", fmt.Errorf("failed to unmarshal JSON: %w", err)
}

// If jsonData is nil or not a valid JSON, return empty string or error.
// Check if JSON data is valid
if jsonData == nil {
return "", fmt.Errorf("invalid json data")
}

// Split JSON path into parts (e.g., "data.items.0.name" -> ["data", "items", "0", "name"])
pathParts := strings.Split(jsonPath, ".")
current := jsonData

// Traverse the JSON structure using the path parts
for _, part := range pathParts {
if current == nil {
return "", fmt.Errorf("invalid json path, not found '%s'", part)
return "", fmt.Errorf("invalid json path: part '%s' not found in path '%s'", part, jsonPath)
}

switch value := current.(type) {
case map[string]interface{}:
// If the current value is a map, look for the key
if next, ok := value[part]; ok {
current = next
} else {
return "", fmt.Errorf("invalid json path, not found '%s'", part)
return "", fmt.Errorf("invalid json path: key '%s' not found in path '%s'", part, jsonPath)
}
case []interface{}:
// If the current value is an array, parse the index
index, err := strconv.Atoi(part)
if err != nil || index < 0 || index >= len(value) {
return "", fmt.Errorf("invalid json path, not found '%s'", part)
return "", fmt.Errorf("invalid json path: index '%s' is out of bounds or invalid in path '%s'", part, jsonPath)
}
current = value[index]
default:
return "", fmt.Errorf("invalid path '%s'", part)
// If the current value is neither a map nor an array, the path is invalid
return "", fmt.Errorf("invalid json path: unexpected type at part '%s' in path '%s'", part, jsonPath)
}
}

// Check if the final value is nil
if current == nil {
return "", fmt.Errorf("invalid path, value is nil '%s'", jsonPath)
return "", fmt.Errorf("invalid json path: value is nil at path '%s'", jsonPath)
}

// Convert the final value to a string
switch v := current.(type) {
case string:
return v, nil
case int, int64, float64, bool:
return fmt.Sprintf("%v", v), nil
default:
// For complex types (e.g., maps, arrays), marshal them back to JSON
jsonBytes, err := json.Marshal(v)
if err != nil {
return "", fmt.Errorf("failed to marshal JSON value at path '%s': %w", jsonPath, err)
}
return string(jsonBytes), nil
}
return fmt.Sprintf("%v", current), nil // Convert value to string (if possible)
}

func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
logID := uuid.New().String()

// Log the request with common fields
m.logRequest(zapcore.InfoLevel, "WAF evaluation started", r, zap.String("log_id", logID))

// Use the custom type as the key
ctx := context.WithValue(r.Context(), ContextKeyLogId("logID"), logID)
r = r.WithContext(ctx)
Expand Down
Loading

0 comments on commit 2c17527

Please sign in to comment.