lox-go/ast/gen/transpile.go

146 lines
3 KiB
Go

package main
import (
"bytes"
"fmt"
"go/format"
"text/template"
)
// defTmplData is the value handed to definitionTemplate when a single
// definition is rendered.
type defTmplData struct {
	// DefName is the definition's own identifier; the template places it in
	// front of Name to form the generated type name.
	DefName string
	// Name is the shared name used for the generated interface and visitor.
	Name string
	// Fields holds the already-rendered struct field lines, one per entry.
	Fields []string
}
// transpiler walks the definition tree and renders it to Go source text
// through two text/template templates.
type transpiler struct {
	// Name is copied from the root definitions node when it is visited.
	Name string
	// DefinitionNames accumulates the identifier of every visited definition.
	DefinitionNames []string
	// Definitions accumulates the rendered source of every visited definition.
	Definitions []string

	// astDefsTmpl renders the complete output file.
	astDefsTmpl *template.Template
	// defTmpl renders one definition.
	defTmpl *template.Template
}
// astDefinitionsTemplate is the text/template source for the whole generated
// file. It reads .Name, .DefinitionNames and .Definitions from its data and
// emits the package clause, the <Name>Visitor interface with one Visit method
// per definition name, the <Name> interface, and finally every entry of
// .Definitions verbatim.
//
// The header line follows the standard Go generated-code marker
// ("// Code generated ... DO NOT EDIT.") and is placed before the package
// clause, separated from it by a blank line so that it is recognised by
// tooling and is not treated as the package doc comment.
var astDefinitionsTemplate = `
{{ $name := .Name }}
// Code generated by lox-go/ast/gen. DO NOT EDIT.

package ast
import "git.red-panda.pet/pandaware/lox-go/lexer"
type {{ .Name }}Visitor interface {
{{ range .DefinitionNames }}Visit{{ . }}{{ $name }}(v *{{ . }}{{ $name }}) any{{ "\n" }}{{ end }}
}
type {{ .Name }} interface {
Accept(v {{ .Name }}Visitor) any
}
{{ range .Definitions }}
{{ . }}
{{ end }}`
// definitionTemplate is the text/template source for one generated node type.
// It is executed with a defTmplData value and emits a struct named
// <DefName><Name> whose body is the Fields entries (one per line), an Accept
// method that forwards to Visit<DefName><Name> on the visitor, and a
// compile-time assertion that a pointer to the struct satisfies <Name>.
var definitionTemplate = `
type {{ .DefName }}{{ .Name }} struct {
{{ range .Fields }}{{ . }}{{ "\n" }}{{ end }}
}
func (n *{{ .DefName }}{{ .Name }}) Accept(v {{ .Name }}Visitor) any {
return v.Visit{{ .DefName }}{{ .Name }}(n)
}
var _ {{ .Name }} = new({{ .DefName }}{{ .Name }})`
// visitASTDefinitionsNode implements visitor.
//
// It stores the root name on the transpiler, renders every child definition
// in order (stopping at the first error), and then executes the file-level
// template with the transpiler itself as data. The buffer contents are
// returned together with whatever error the template execution produced.
func (t *transpiler) visitASTDefinitionsNode(a *astDefinitionsNode) (string, error) {
	t.Name = a.name

	for _, child := range a.definitions {
		rendered, err := child.accept(t)
		if err != nil {
			return "", err
		}
		t.Definitions = append(t.Definitions, rendered)
	}

	out := new(bytes.Buffer)
	execErr := t.astDefsTmpl.Execute(out, t)
	return out.String(), execErr
}
// visitDefinition implements visitor.
//
// It resolves the definition's identifier, records it in t.DefinitionNames,
// renders each field, and executes the per-definition template with the
// collected data. It returns the rendered source, or the first error
// encountered while visiting children or executing the template.
func (t *transpiler) visitDefinition(d *definitionNode) (string, error) {
	name, err := d.identifier.accept(t)
	if err != nil {
		return "", err
	}
	t.DefinitionNames = append(t.DefinitionNames, name)

	// Start with length 0 and only reserve capacity: the loop appends, so a
	// non-zero initial length would leave len(d.fields) empty strings at the
	// front of the slice and the template would render them as empty lines.
	fields := make([]string, 0, len(d.fields))
	for _, fieldNode := range d.fields {
		field, err := fieldNode.accept(t)
		if err != nil {
			return "", err
		}
		fields = append(fields, field)
	}

	var buf bytes.Buffer
	data := defTmplData{
		DefName: name,
		Name:    t.Name,
		Fields:  fields,
	}
	if err := t.defTmpl.Execute(&buf, data); err != nil {
		// Do not hand back partially rendered output alongside an error.
		return "", err
	}
	return buf.String(), nil
}
// visitField implements visitor.
//
// Both sides of the field are rendered (left first) and joined into a single
// line of the form "\t<left>\t<right>". The first error from either side is
// returned with an empty string.
func (t *transpiler) visitField(g *fieldNode) (string, error) {
	lhs, lerr := g.left.accept(t)
	if lerr != nil {
		return "", lerr
	}

	rhs, rerr := g.right.accept(t)
	if rerr != nil {
		return "", rerr
	}

	line := fmt.Sprintf("\t%s\t%s", lhs, rhs)
	return line, nil
}
// visitIdentifier implements visitor by returning the identifier node's
// stored value unchanged; it never fails.
func (t *transpiler) visitIdentifier(i *identifierNode) (string, error) {
	text := i.value
	return text, nil
}
// visitName implements visitor by returning the name node's stored value
// unchanged; it never fails.
func (t *transpiler) visitName(n *nameNode) (string, error) {
	text := n.value
	return text, nil
}
// Compile-time check that *transpiler satisfies visitor. The typed nil
// pointer performs the check without allocating a transpiler at package init.
var _ visitor = (*transpiler)(nil)
// transpile renders the tree rooted at n to Go source and formats it with
// go/format.
//
// It returns the name the transpiler recorded while visiting n, the formatted
// source bytes, and an error. An error is returned if either template fails
// to parse, if visiting n fails, or if the rendered text cannot be formatted;
// in the last case the recorded name is still returned alongside the error
// from format.Source.
func transpile(n node) (string, []byte, error) {
	// Check each Parse result on its own: a failed parse would otherwise
	// leave a nil template that is later executed.
	astDefsTmpl, err := template.New("").Parse(astDefinitionsTemplate)
	if err != nil {
		return "", nil, fmt.Errorf("parsing AST definitions template: %w", err)
	}
	defTmpl, err := template.New("").Parse(definitionTemplate)
	if err != nil {
		return "", nil, fmt.Errorf("parsing definition template: %w", err)
	}

	t := &transpiler{
		astDefsTmpl: astDefsTmpl,
		defTmpl:     defTmpl,
	}

	str, err := n.accept(t)
	if err != nil {
		return "", nil, err
	}

	bs, err := format.Source([]byte(str))
	return t.Name, bs, err
}