diff --git a/html/extractor.go b/html/extractor.go index ebd87f2..25a3e4c 100644 --- a/html/extractor.go +++ b/html/extractor.go @@ -2,10 +2,11 @@ package html import ( "golang.org/x/net/html" + "regexp" "strings" ) -// HTMLExtractor represents an HTML-specific plain text extractor. +// Extractor represents an HTML-specific plain text extractor. type Extractor struct { blockTags map[string]bool } @@ -34,7 +35,7 @@ func (e *Extractor) PlainText(input string) (*string, error) { e.extractText(&plainText, doc) output := plainText.String() - output = strings.ReplaceAll(output, "\n ", "\n") + output = string(regexp.MustCompile("\n+\\s+").ReplaceAll([]byte(output), []byte("\n"))) return &output, nil } @@ -45,11 +46,7 @@ func (e *Extractor) extractText(plainText *strings.Builder, node *html.Node) { text := strings.TrimSpace(node.Data) if text != "" { if plainText.Len() > 0 { - if found := e.blockTags[node.Parent.DataAtom.String()]; found { - plainText.WriteString("\n") - } else { - plainText.WriteString(" ") - } + plainText.WriteString(" ") } plainText.WriteString(text) } @@ -62,4 +59,7 @@ func (e *Extractor) extractText(plainText *strings.Builder, node *html.Node) { for child := node.FirstChild; child != nil; child = child.NextSibling { e.extractText(plainText, child) } + if found := e.blockTags[node.DataAtom.String()]; found { + plainText.WriteString("\n") + } } diff --git a/html/extractor_test.go b/html/extractor_test.go index 98fb648..b23e07e 100644 --- a/html/extractor_test.go +++ b/html/extractor_test.go @@ -1,6 +1,7 @@ package html import ( + _ "embed" "github.com/stretchr/testify/assert" "testing" ) @@ -12,10 +13,12 @@ func TestExtract(t *testing.T) { expected string }{ {`a
b`, "a\nb"}, - {`a

b

`, "a\n\nb"}, + {`a

b

`, "a\nb\n"}, {`link`, "link"}, - {`
This is a link
`, "This is a link"}, - {"

Heading 1

Heading 2

", "Heading 1\nHeading 2\nItem 1\nItem 2"}, + {`
This is a link
`, "This is a link\n"}, + {"

Heading 1

Heading 2

", "Heading 1\nHeading 2\nItem 1\nItem 2\n"}, + {"

ab

c", "a b\nc"}, + {"a\n \nb", "a\nb"}, } for _, test := range tests { output, err := extractor.PlainText(test.input)