diff --git a/html/extractor.go b/html/extractor.go
index ebd87f2..25a3e4c 100644
--- a/html/extractor.go
+++ b/html/extractor.go
@@ -2,10 +2,11 @@ package html
import (
"golang.org/x/net/html"
+ "regexp"
"strings"
)
-// HTMLExtractor represents an HTML-specific plain text extractor.
+// Extractor represents an HTML-specific plain text extractor.
type Extractor struct {
blockTags map[string]bool
}
@@ -34,7 +35,7 @@ func (e *Extractor) PlainText(input string) (*string, error) {
e.extractText(&plainText, doc)
output := plainText.String()
- output = strings.ReplaceAll(output, "\n ", "\n")
+ output = string(regexp.MustCompile("\n+\\s+").ReplaceAll([]byte(output), []byte("\n")))
return &output, nil
}
@@ -45,11 +46,7 @@ func (e *Extractor) extractText(plainText *strings.Builder, node *html.Node) {
text := strings.TrimSpace(node.Data)
if text != "" {
if plainText.Len() > 0 {
- if found := e.blockTags[node.Parent.DataAtom.String()]; found {
- plainText.WriteString("\n")
- } else {
- plainText.WriteString(" ")
- }
+ plainText.WriteString(" ")
}
plainText.WriteString(text)
}
@@ -62,4 +59,7 @@ func (e *Extractor) extractText(plainText *strings.Builder, node *html.Node) {
for child := node.FirstChild; child != nil; child = child.NextSibling {
e.extractText(plainText, child)
}
+ if found := e.blockTags[node.DataAtom.String()]; found {
+ plainText.WriteString("\n")
+ }
}
diff --git a/html/extractor_test.go b/html/extractor_test.go
index 98fb648..b23e07e 100644
--- a/html/extractor_test.go
+++ b/html/extractor_test.go
@@ -1,6 +1,7 @@
package html
import (
+ _ "embed"
"github.com/stretchr/testify/assert"
"testing"
)
@@ -12,10 +13,12 @@ func TestExtract(t *testing.T) {
expected string
}{
{`a
b`, "a\nb"},
- {`a
ab
c", "a b\nc"}, + {"a\n \nb", "a\nb"}, } for _, test := range tests { output, err := extractor.PlainText(test.input)