diff options
| -rw-r--r-- | markdown/ast_code.go | 24 | ||||
| -rw-r--r-- | markdown/ast_code_test.go | 33 |
2 files changed, 44 insertions, 13 deletions
diff --git a/markdown/ast_code.go b/markdown/ast_code.go index 6b949f2..4fa962a 100644 --- a/markdown/ast_code.go +++ b/markdown/ast_code.go @@ -42,28 +42,26 @@ func (a *astCode) Eval(_ *Option) (template.HTML, *ParseError) { func code(lxs *lexers) (*astCode, *ParseError) { tree := new(astCode) - current := lxs.Current().Value - if len(current) == 3 { + codeTag := lxs.Current().Value + if len(codeTag) == 3 { tree.codeType = codeMultiLine - } else if len(current) == 1 { + } else if len(codeTag) == 1 { tree.codeType = codeOneLine } else { return nil, &ParseError{lxs: *lxs, internal: ErrInvalidCodeFormat} } started := false - for lxs.Next() && lxs.Current().Value != current { - if lxs.Current().Type == lexerBreak { + for lxs.Next() && lxs.Current().Value != codeTag { + isBreak := lxs.Current().Type == lexerBreak + if started || (tree.codeType == codeOneLine && !isBreak) { + tree.content += lxs.Current().Value + } else if !isBreak { + tree.before += lxs.Current().Value + } else { if tree.codeType == codeOneLine { return nil, &ParseError{lxs: *lxs, internal: ErrInvalidCodeFormat} } - if !started { - started = true - } - } - if started || tree.codeType == codeOneLine { - tree.content += lxs.Current().Value - } else { - tree.before += lxs.Current().Value + started = true } } return tree, nil diff --git a/markdown/ast_code_test.go b/markdown/ast_code_test.go new file mode 100644 index 0000000..b116544 --- /dev/null +++ b/markdown/ast_code_test.go @@ -0,0 +1,33 @@ +package markdown + +import "testing" + +func TestCode(t *testing.T) { + got, err := Parse("`mono`", nil) + if err != nil { + t.Fatal(err) + } + if string(got) != `<p><code>mono</code></p>` { + t.Errorf("invalid value, got %s", got) + } + + got, err = Parse("bonjour `code` !", nil) + if err != nil { + t.Fatal(err) + } + if string(got) != `<p>bonjour <code>code</code> !</p>` { + t.Errorf("invalid value, got %s", got) + } + + got, err = Parse( + "```\n"+"raw\nhehe"+"```", + nil, + ) + if err != nil { + t.Fatal(err) + } + if string(got) != `<pre><code>raw +hehe</code></pre>` { + t.Errorf("invalid value, got %s", got) + } +} |
