Wrote a DSL for AST boilerplate generation

This commit is contained in:
basil 2025-06-08 21:18:17 -04:00
parent b244f7e3b2
commit e0dd8ff9d5
Signed by: basil
SSH key fingerprint: SHA256:y04xIFL/yqNaG9ae9Vl95vELtHfApGAIoOGLeVLP/fE
16 changed files with 915 additions and 60 deletions

55
ast/gen/debug.go Normal file
View file

@ -0,0 +1,55 @@
package main
// debugVisitor is a stateless visitor that renders AST nodes as text.
// The parser uses it (through debug) to name nodes in its error messages.
type debugVisitor struct{}
// debug renders node as text by walking it with a fresh debugVisitor.
// Any error reported by the walk is dropped; only the rendered text is
// returned.
func debug(node node) string {
	text, _ := node.accept(&debugVisitor{})
	return text
}
// visitASTDefinitionsNode implements visitor.
//
// It renders the root node as a "#<name>" header followed by every
// definition, each entry followed by a blank line. An error returned by a
// child node stops the walk and is passed back to the caller instead of
// being silently dropped.
func (d *debugVisitor) visitASTDefinitionsNode(a *astDefinitionsNode) (string, error) {
	v := "#" + a.name + "\n\n"
	for _, defNode := range a.definitions {
		def, err := defNode.accept(d)
		if err != nil {
			return "", err
		}
		v += def + "\n\n"
	}
	return v, nil
}
// visitDefinition implements visitor.
//
// It renders a definition as its identifier, an opening "[", one
// tab-indented line per field, and a closing "]". An error returned by the
// identifier or by any field is propagated rather than discarded.
func (d *debugVisitor) visitDefinition(def *definitionNode) (string, error) {
	id, err := def.identifier.accept(d)
	if err != nil {
		return "", err
	}
	v := id + " [\n"
	for _, fNode := range def.fields {
		field, err := fNode.accept(d)
		if err != nil {
			return "", err
		}
		v += "\t" + field + "\n"
	}
	v += "]\n\n"
	return v, nil
}
// visitField implements visitor.
//
// It renders a field as "<left> = <right>;". An error returned by either
// side is propagated rather than discarded.
func (d *debugVisitor) visitField(g *fieldNode) (string, error) {
	left, err := g.left.accept(d)
	if err != nil {
		return "", err
	}
	right, err := g.right.accept(d)
	if err != nil {
		return "", err
	}
	return left + " = " + right + ";", nil
}
// visitIdentifier implements visitor.
//
// An identifier renders as its stored text, unchanged; it never fails.
func (d *debugVisitor) visitIdentifier(i *identifierNode) (string, error) {
	text := i.value
	return text, nil
}
// visitName implements visitor.
//
// A name renders as its stored text, unchanged; it never fails.
func (d *debugVisitor) visitName(n *nameNode) (string, error) {
	text := n.value
	return text, nil
}
// Compile-time check that *debugVisitor satisfies visitor.
var _ visitor = (*debugVisitor)(nil)

52
ast/gen/generate.go Normal file
View file

@ -0,0 +1,52 @@
package main
import (
"flag"
"fmt"
"os"
"strings"
)
// fileName is the path of the definition file to generate from. It is
// populated from the -f command-line flag and defaults to the empty string.
var fileName string
// init registers the -f flag; the value is only available after flag.Parse
// has run (main calls it).
func init() {
flag.StringVar(&fileName, "f", "", "ast file to generate from")
}
// main reads the definition file named by the -f flag, runs it through
// lex, parse and transpile, and writes the resulting bytes to "<name>.go"
// in the current directory, where <name> is the lower-cased name returned
// by transpile. Every failure is printed and ends the process with exit
// status 1.
func main() {
	flag.Parse()

	// fail prints err and terminates the process with a non-zero status.
	fail := func(err error) {
		fmt.Printf("%s\n", err.Error())
		os.Exit(1)
	}

	bs, err := os.ReadFile(fileName)
	if err != nil {
		// Without this check an unreadable file would be lexed as empty
		// input and reported with a misleading parse error.
		fail(err)
	}

	tokens, errs := lex(string(bs))
	if len(errs) > 0 {
		// Report every lexical error before giving up.
		for _, err := range errs {
			fmt.Printf("%s\n", err.Error())
		}
		os.Exit(1)
	}

	ast, err := parse(tokens)
	if err != nil {
		fail(err)
	}

	name, output, err := transpile(ast)
	if err != nil {
		fail(err)
	}

	// os.WriteFile creates or truncates the file, writes it and closes it,
	// so the handle is not leaked and a failing close surfaces as an error.
	// Mode 0666 matches what os.Create would have used.
	if err := os.WriteFile(strings.ToLower(name)+".go", output, 0666); err != nil {
		fail(err)
	}
}

180
ast/gen/lex.go Normal file
View file

@ -0,0 +1,180 @@
package main
import (
"errors"
"fmt"
)
// tokenType identifies the kind of a token produced by the lexer.
type tokenType int
// Token kinds, numbered from zero by iota in declaration order.
const (
// tokenTypeIdentifier is a run of identifier characters (see isIdentifier).
tokenTypeIdentifier tokenType = iota
// tokenTypeRightBracket is ']'.
tokenTypeRightBracket
// tokenTypeLeftBracket is '['.
tokenTypeLeftBracket
// tokenTypeEqual is '='.
tokenTypeEqual
// tokenTypeName is a '#'-prefixed name; the lexeme excludes the '#'.
tokenTypeName
// tokenTypeSemicolon is ';'.
tokenTypeSemicolon
// tokenTypeEOF is appended once by scanTokens after all input is consumed.
tokenTypeEOF
)
// token is a single lexical unit produced by the lexer.
type token struct {
// Type is the kind of token.
Type tokenType
// Lexeme is the source text the token was built from.
Lexeme string
// Line is the lexer's line counter at the time the token was emitted.
Line int
}
// isUpperAlpha reports whether r is an ASCII uppercase letter.
func isUpperAlpha(r rune) bool {
	return 'A' <= r && r <= 'Z'
}

// isAlpha reports whether r is an ASCII letter of either case.
func isAlpha(r rune) bool {
	if 'a' <= r && r <= 'z' {
		return true
	}
	return isUpperAlpha(r)
}

// isNumeric reports whether r is an ASCII decimal digit.
func isNumeric(r rune) bool {
	return '0' <= r && r <= '9'
}

// isAlphaNumeric reports whether r is an ASCII letter or digit.
func isAlphaNumeric(r rune) bool {
	return isNumeric(r) || isAlpha(r)
}

// isGoIdentifier reports whether r is an ASCII letter, digit or underscore.
func isGoIdentifier(r rune) bool {
	return r == '_' || isAlphaNumeric(r)
}

// isIdentifier reports whether r may appear in a DSL identifier: anything
// accepted by isGoIdentifier, plus '.' and '*'.
func isIdentifier(r rune) bool {
	switch r {
	case '.', '*':
		return true
	}
	return isGoIdentifier(r)
}
// lexer holds the scanning state used to turn source text into tokens.
type lexer struct {
// source is the input as runes, so indices count characters, not bytes.
source []rune
// current is the index of the next rune advance will return.
current int
// start is the index where the token being scanned began; scanTokens
// resets it before every scanToken call.
start int
// line is the current line number; lex starts it at 1 and scanToken
// increments it on '\n'.
line int
// tokens accumulates the tokens produced so far.
tokens []*token
}
// addToken appends a token of kind t whose lexeme is the source text
// between start and current, tagged with the current line.
func (l *lexer) addToken(t tokenType) {
	tok := &token{Type: t, Line: l.line}
	tok.Lexeme = string(l.source[l.start:l.current])
	l.tokens = append(l.tokens, tok)
}
// peek returns the next unread rune without consuming it, or the zero rune
// when the input is exhausted.
func (l *lexer) peek() rune {
	if !l.isAtEnd() {
		return l.source[l.current]
	}
	return 0
}
// peekNext returns the rune one past the next unread rune without
// consuming anything, or the zero rune when there is no such rune.
func (l *lexer) peekNext() rune {
	if next := l.current + 1; next < len(l.source) {
		return l.source[next]
	}
	return 0
}
// advance consumes and returns the next rune. It does not bounds-check:
// callers must make sure the lexer is not at the end of the input.
func (l *lexer) advance() rune {
	consumed := l.source[l.current]
	l.current++
	return consumed
}
// isAtEnd reports whether every rune of the source has been consumed.
func (l *lexer) isAtEnd() bool {
	return len(l.source) <= l.current
}
// scanToken consumes one rune and, depending on it, emits a punctuation
// token, skips whitespace, counts a newline, or hands over to name or
// identifier to scan a longer token.
//
// It returns an error when '#' is not followed by an uppercase ASCII
// letter, or when the rune cannot start any token. Both errors report the
// line they occurred on.
func (l *lexer) scanToken() error {
	switch r := l.advance(); r {
	case '[':
		l.addToken(tokenTypeLeftBracket)
	case ']':
		l.addToken(tokenTypeRightBracket)
	case '=':
		l.addToken(tokenTypeEqual)
	case ';':
		l.addToken(tokenTypeSemicolon)
	case ' ', '\r', '\t':
		// Insignificant whitespace: nothing to emit.
	case '\n':
		l.line++
	case '#':
		next := l.peek()
		if !isUpperAlpha(next) {
			// errors.New is kept (rather than fmt.Errorf) so the file's
			// "errors" import remains in use.
			return errors.New(fmt.Sprintf("names must have an uppercase alphabetical first character, found '%s' at line %d", string(next), l.line))
		}
		l.name()
	default:
		if !isIdentifier(r) {
			return errors.New(fmt.Sprintf("unexpected character '%s' at line %d", string(r), l.line))
		}
		l.identifier()
	}
	return nil
}
// name scans the rest of a '#'-prefixed name and emits a tokenTypeName
// token whose lexeme excludes the leading '#'.
//
// It stops at the first rune that is not a letter, digit or underscore and
// leaves that rune unconsumed. This lets scanToken process it normally (so
// a following newline is counted and a following '[' is tokenised), and it
// avoids indexing past the end of the source when the name is the last
// thing in the input.
func (l *lexer) name() {
	// peek yields the zero rune at end of input, which isGoIdentifier
	// rejects, so the loop terminates without an explicit end check.
	for isGoIdentifier(l.peek()) {
		l.advance()
	}
	l.tokens = append(l.tokens, &token{
		Type:   tokenTypeName,
		Lexeme: string(l.source[l.start+1 : l.current]), // +1 skips '#'
		Line:   l.line,
	})
}
// identifier consumes the remaining identifier runes and emits a
// tokenTypeIdentifier token covering the text from start to current.
func (l *lexer) identifier() {
	for {
		if !isIdentifier(l.peek()) {
			break
		}
		l.advance()
	}
	l.addToken(tokenTypeIdentifier)
}
// scanTokens scans the whole source and returns every token produced,
// terminated by a tokenTypeEOF token, together with all errors met along
// the way. Scanning continues after an error, so several problems can be
// reported in one run. The returned error slice is empty, not nil, when
// nothing went wrong.
func (l *lexer) scanTokens() ([]*token, []error) {
	errs := []error{}
	for !l.isAtEnd() {
		l.start = l.current
		if err := l.scanToken(); err != nil {
			errs = append(errs, err)
		}
	}
	// Collapse the span so the EOF token gets an empty lexeme rather than
	// the text of whatever was scanned last.
	l.start = l.current
	l.addToken(tokenTypeEOF)
	return l.tokens, errs
}
// lex tokenises source and returns the tokens (always ending with an EOF
// token) together with any lexical errors. Line numbering starts at 1.
func lex(source string) ([]*token, []error) {
	l := &lexer{
		source: []rune(source),
		line:   1,
		tokens: []*token{},
	}
	return l.scanTokens()
}

229
ast/gen/parser.go Normal file
View file

@ -0,0 +1,229 @@
package main
import (
"errors"
"fmt"
)
// visitor is implemented by anything that walks the AST. It has one method
// per node type; each returns the text produced for that node or an error.
type visitor interface {
visitASTDefinitionsNode(a *astDefinitionsNode) (string, error)
visitName(n *nameNode) (string, error)
visitIdentifier(i *identifierNode) (string, error)
visitField(g *fieldNode) (string, error)
visitDefinition(d *definitionNode) (string, error)
}
// node is implemented by every AST node. accept calls the visitor method
// that matches the node's concrete type and returns its result.
type node interface {
accept(v visitor) (string, error)
}
// astDefinitionsNode is the root of a parsed file: the name taken from the
// leading name token plus every definition that follows it.
type astDefinitionsNode struct {
	name        string
	definitions []node
}

// accept implements node by dispatching to visitASTDefinitionsNode.
func (a *astDefinitionsNode) accept(v visitor) (string, error) {
	out, err := v.visitASTDefinitionsNode(a)
	return out, err
}

// Compile-time check that *astDefinitionsNode satisfies node.
var _ node = (*astDefinitionsNode)(nil)
// nameNode wraps a name's text.
// NOTE(review): nothing in the visible parser constructs this node; the
// root node stores its name as a plain string instead.
type nameNode struct {
	value string
}

// accept implements node by dispatching to visitName.
func (n *nameNode) accept(v visitor) (string, error) {
	out, err := v.visitName(n)
	return out, err
}

// Compile-time check that *nameNode satisfies node.
var _ node = (*nameNode)(nil)
// identifierNode wraps the lexeme of a single identifier token.
type identifierNode struct {
	value string
}

// accept implements node by dispatching to visitIdentifier.
func (i *identifierNode) accept(v visitor) (string, error) {
	out, err := v.visitIdentifier(i)
	return out, err
}

// Compile-time check that *identifierNode satisfies node.
var _ node = (*identifierNode)(nil)
// fieldNode is one `left = right;` entry inside a definition.
type fieldNode struct {
	left  node
	right node
}

// accept implements node by dispatching to visitField.
func (f *fieldNode) accept(v visitor) (string, error) {
	out, err := v.visitField(f)
	return out, err
}

// Compile-time check that *fieldNode satisfies node.
var _ node = (*fieldNode)(nil)
// definitionNode is one `identifier [ fields ]` entry of the file.
type definitionNode struct {
	identifier node
	fields     []node
}

// accept implements node by dispatching to visitDefinition.
func (d *definitionNode) accept(v visitor) (string, error) {
	out, err := v.visitDefinition(d)
	return out, err
}

// Compile-time check that *definitionNode satisfies node.
var _ node = (*definitionNode)(nil)
// parser is a recursive-descent parser over a token slice.
type parser struct {
// tokens is the input; the parser's end-of-input test (isAtEnd) relies on
// it ending with a tokenTypeEOF token.
tokens []*token
// current is the index of the next token to be consumed.
current int
}
// peek returns the next token without consuming it. It does not
// bounds-check the token slice.
func (p *parser) peek() *token {
	idx := p.current
	return p.tokens[idx]
}
// isAtEnd reports whether the next token is the EOF token.
func (p *parser) isAtEnd() bool {
	next := p.peek()
	return next.Type == tokenTypeEOF
}
// check reports whether the next token has type t, without consuming it.
// It is always false at end of input, even for tokenTypeEOF.
func (p *parser) check(t tokenType) bool {
	return !p.isAtEnd() && p.peek().Type == t
}
// match consumes the next token and returns true when its type is one of
// types; otherwise it consumes nothing and returns false.
func (p *parser) match(types ...tokenType) bool {
	for i := range types {
		if !p.check(types[i]) {
			continue
		}
		p.advance()
		return true
	}
	return false
}
// advance consumes and returns the next token. It does not bounds-check
// the token slice.
func (p *parser) advance() *token {
	consumed := p.tokens[p.current]
	p.current++
	return consumed
}
// consume requires the next token to have type t. On success it consumes
// and returns that token; otherwise it consumes nothing and returns an
// error carrying msg.
func (p *parser) consume(t tokenType, msg string) (*token, error) {
	if !p.check(t) {
		return nil, errors.New(msg)
	}
	return p.advance(), nil
}
// previous returns the most recently consumed token. It must only be
// called after at least one token has been consumed.
func (p *parser) previous() *token {
	last := p.current - 1
	return p.tokens[last]
}
// astDefinitions parses a whole file: a name token followed by zero or
// more definitions up to end of input. It returns the root node, or the
// first error encountered.
func (p *parser) astDefinitions() (node, error) {
	if !p.match(tokenTypeName) {
		return nil, errors.New("expected name definition at start of file")
	}

	root := &astDefinitionsNode{
		name:        p.previous().Lexeme,
		definitions: []node{},
	}
	for !p.isAtEnd() {
		def, err := p.definition()
		if err != nil {
			return nil, err
		}
		root.definitions = append(root.definitions, def)
	}
	return root, nil
}
// definition -> identifier "[" field* "]"
//
// definition parses one definition. An empty field list is accepted. It
// returns an error when the identifier or '[' is missing, when a field is
// malformed, or when the input ends before the closing ']'.
func (p *parser) definition() (node, error) {
	id, err := p.identifier()
	if err != nil {
		return nil, err
	}
	if !p.match(tokenTypeLeftBracket) {
		return nil, fmt.Errorf("expected '[' after identifier '%s'", debug(id))
	}

	fields := []node{}
	for !p.check(tokenTypeRightBracket) && !p.isAtEnd() {
		f, err := p.field()
		if err != nil {
			return nil, err
		}
		fields = append(fields, f)
	}

	// The loop only stops on ']' or end of input, so a failed match here
	// can only mean the input ran out.
	if !p.match(tokenTypeRightBracket) {
		return nil, fmt.Errorf("expected ']' after field definitions in '%s', got EOF", debug(id))
	}
	return &definitionNode{
		identifier: id,
		fields:     fields,
	}, nil
}
// field -> identifier "=" identifier ";"
//
// field parses one field. It returns an error when either identifier, the
// '=' or the terminating ';' is missing; the '=' and ';' errors name the
// left-hand identifier.
func (p *parser) field() (node, error) {
	left, err := p.identifier()
	if err != nil {
		return nil, err
	}
	if !p.match(tokenTypeEqual) {
		return nil, fmt.Errorf("expected '=' after identifier '%s'", debug(left))
	}

	right, err := p.identifier()
	if err != nil {
		return nil, err
	}
	if !p.match(tokenTypeSemicolon) {
		return nil, fmt.Errorf("expected ';' at end of field '%s'", debug(left))
	}

	return &fieldNode{
		left:  left,
		right: right,
	}, nil
}
// identifier consumes one identifier token and wraps its lexeme in an
// identifierNode. It returns an error when the next token is anything else.
func (p *parser) identifier() (node, error) {
	if !p.match(tokenTypeIdentifier) {
		return nil, errors.New("expected identifier")
	}
	return &identifierNode{value: p.previous().Lexeme}, nil
}
// parse builds the AST for tokens, starting from the first token, and
// returns its root node or the first syntax error.
func parse(tokens []*token) (node, error) {
	p := &parser{tokens: tokens}
	return p.astDefinitions()
}

146
ast/gen/transpile.go Normal file
View file

@ -0,0 +1,146 @@
package main
import (
"bytes"
"fmt"
"go/format"
"text/template"
)
// defTmplData is the data passed to definitionTemplate for one definition.
type defTmplData struct {
// DefName is the definition's own identifier text.
DefName string
// Name is the file-level name (transpiler.Name); the template appends it
// to DefName to form the generated type name.
Name string
// Fields holds the rendered struct field lines, one entry per field.
Fields []string
}
// transpiler is a visitor that renders the AST as Go source text. Its
// exported fields are read by astDefinitionsTemplate, which is executed
// with the transpiler itself as data.
type transpiler struct {
// Name is the file-level name, copied from the root node.
Name string
// DefinitionNames collects the identifier of every definition visited.
DefinitionNames []string
// Definitions collects the rendered text of every definition visited.
Definitions []string
// astDefsTmpl renders the whole output file.
astDefsTmpl *template.Template
// defTmpl renders a single definition.
defTmpl *template.Template
}
// astDefinitionsTemplate is the text/template source for the whole
// generated file. It is executed with a *transpiler and emits a
// `package ast` file containing a <Name>Visitor interface with one
// Visit<Def><Name> method per entry of DefinitionNames, a <Name> interface
// with an Accept method, and then every entry of Definitions.
var astDefinitionsTemplate = `
{{ $name := .Name }}
package ast
// THIS FILE WAS AUTOMATICALLY GENERATED, DO NOT MANUALLY EDIT
import "git.red-panda.pet/pandaware/lox-go/lexer"
type {{ .Name }}Visitor interface {
{{ range .DefinitionNames }}Visit{{ . }}{{ $name }}(v *{{ . }}{{ $name }}) any{{ "\n" }}{{ end }}
}
type {{ .Name }} interface {
Accept(v {{ .Name }}Visitor) any
}
{{ range .Definitions }}
{{ . }}
{{ end }}`
// definitionTemplate is the text/template source for one definition. It
// is executed with a defTmplData and emits a <DefName><Name> struct holding
// the given field lines, an Accept method that calls the visitor's
// Visit<DefName><Name> method, and a compile-time assertion that the
// struct pointer satisfies the <Name> interface.
var definitionTemplate = `
type {{ .DefName }}{{ .Name }} struct {
{{ range .Fields }}{{ . }}{{ "\n" }}{{ end }}
}
func (n *{{ .DefName }}{{ .Name }}) Accept(v {{ .Name }}Visitor) any {
return v.Visit{{ .DefName }}{{ .Name }}(n)
}
var _ {{ .Name }} = new({{ .DefName }}{{ .Name }})`
// visitASTDefinitionsNode implements visitor.
//
// It records the file-level name, renders every definition in order
// (stopping at the first error), and then executes the file template with
// the transpiler as data. The template output is returned together with
// any execution error.
func (t *transpiler) visitASTDefinitionsNode(a *astDefinitionsNode) (string, error) {
	t.Name = a.name

	for i := range a.definitions {
		rendered, err := a.definitions[i].accept(t)
		if err != nil {
			return "", err
		}
		t.Definitions = append(t.Definitions, rendered)
	}

	var out bytes.Buffer
	err := t.astDefsTmpl.Execute(&out, t)
	return out.String(), err
}
// visitDefinition implements visitor.
//
// It records the definition's name in DefinitionNames, renders each field,
// and executes the definition template. The template output is returned
// together with any execution error; an error from the identifier or from
// a field aborts early with an empty string.
func (t *transpiler) visitDefinition(d *definitionNode) (string, error) {
	name, err := d.identifier.accept(t)
	if err != nil {
		return "", err
	}
	t.DefinitionNames = append(t.DefinitionNames, name)

	// Length 0 with capacity len(d.fields): appending to a slice created
	// with a non-zero length would leave that many empty strings in front
	// of the real fields and produce blank lines in the generated struct.
	fields := make([]string, 0, len(d.fields))
	for _, f := range d.fields {
		field, err := f.accept(t)
		if err != nil {
			return "", err
		}
		fields = append(fields, field)
	}

	buf := &bytes.Buffer{}
	err = t.defTmpl.Execute(buf, defTmplData{
		DefName: name,
		Name:    t.Name,
		Fields:  fields,
	})
	return buf.String(), err
}
// visitField implements visitor.
//
// It renders the left side, then the right side, and joins them into one
// tab-indented, tab-separated, newline-terminated line. The first error
// from either side is returned and nothing further is rendered.
func (t *transpiler) visitField(g *fieldNode) (string, error) {
	var cols [2]string
	for i, side := range [2]node{g.left, g.right} {
		text, err := side.accept(t)
		if err != nil {
			return "", err
		}
		cols[i] = text
	}
	return fmt.Sprintf("\t%s\t%s\n", cols[0], cols[1]), nil
}
// visitIdentifier implements visitor.
//
// An identifier is emitted verbatim; it never fails.
func (t *transpiler) visitIdentifier(i *identifierNode) (string, error) {
	text := i.value
	return text, nil
}
// visitName implements visitor.
//
// A name is emitted verbatim; it never fails.
func (t *transpiler) visitName(n *nameNode) (string, error) {
	text := n.value
	return text, nil
}
// Compile-time check that *transpiler satisfies visitor.
var _ visitor = (*transpiler)(nil)
// transpile renders the AST rooted at n as gofmt-formatted Go source.
//
// It returns the file-level name recorded while walking the tree, the
// formatted source bytes, and an error. An error is returned when either
// template fails to parse, when walking the tree fails, or when the
// rendered text is not valid Go (reported by format.Source; the name is
// still returned in that last case).
func transpile(n node) (string, []byte, error) {
	// Parse both templates up front and check each result, so a broken
	// template is reported here instead of surfacing later as a nil
	// template being executed.
	astDefsTmpl, err := template.New("").Parse(astDefinitionsTemplate)
	if err != nil {
		return "", nil, fmt.Errorf("parsing ast definitions template: %w", err)
	}
	defTmpl, err := template.New("").Parse(definitionTemplate)
	if err != nil {
		return "", nil, fmt.Errorf("parsing definition template: %w", err)
	}

	t := &transpiler{
		DefinitionNames: []string{},
		Definitions:     []string{},
		astDefsTmpl:     astDefsTmpl,
		defTmpl:         defTmpl,
	}

	str, err := n.accept(t)
	if err != nil {
		return "", nil, err
	}

	bs, err := format.Source([]byte(str))
	return t.Name, bs, err
}