package main

import (
	"bytes"
	"fmt"
	"go/format"
	"text/template"
)

// defTmplData is the data passed to definitionTemplate for a single definition.
type defTmplData struct {
	DefName string
	Name    string
	Fields  []string
}

// transpiler walks the definition tree and renders Go source for the AST types.
type transpiler struct {
	Name            string
	DefinitionNames []string
	Definitions     []string

	astDefsTmpl *template.Template
	defTmpl     *template.Template
}

var astDefinitionsTemplate = `
{{ $name := .Name }}
package ast

// THIS FILE WAS AUTOMATICALLY GENERATED, DO NOT MANUALLY EDIT

import "git.red-panda.pet/pandaware/lox-go/lexer"

type {{ .Name }}Visitor interface {
{{ range .DefinitionNames }}Visit{{ . }}{{ $name }}(v *{{ . }}{{ $name }}) any{{ "\n" }}{{ end }}
}

type {{ .Name }} interface {
	Accept(v {{ .Name }}Visitor) any
}

{{ range .Definitions }}
{{ . }}
{{ end }}`

var definitionTemplate = `
type {{ .DefName }}{{ .Name }} struct {
{{ range .Fields }}{{ . }}{{ "\n" }}{{ end }}
}

func (n *{{ .DefName }}{{ .Name }}) Accept(v {{ .Name }}Visitor) any {
	return v.Visit{{ .DefName }}{{ .Name }}(n)
}

var _ {{ .Name }} = new({{ .DefName }}{{ .Name }})`

// visitASTDefinitionsNode implements visitor.
func (t *transpiler) visitASTDefinitionsNode(a *astDefinitionsNode) (string, error) {
	t.Name = a.name

	// Render each definition first so that DefinitionNames and Definitions
	// are fully populated before the top-level template runs.
	for _, defNode := range a.definitions {
		def, err := defNode.accept(t)
		if err != nil {
			return "", err
		}
		t.Definitions = append(t.Definitions, def)
	}

	buf := &bytes.Buffer{}
	if err := t.astDefsTmpl.Execute(buf, t); err != nil {
		return "", err
	}
	return buf.String(), nil
}

// visitDefinition implements visitor.
func (t *transpiler) visitDefinition(d *definitionNode) (string, error) {
	name, err := d.identifier.accept(t)
	if err != nil {
		return "", err
	}
	t.DefinitionNames = append(t.DefinitionNames, name)

	// Allocate with length 0 and capacity len(d.fields); a non-zero length
	// combined with append would leave empty strings at the front.
	fields := make([]string, 0, len(d.fields))
	for _, fieldNode := range d.fields {
		field, err := fieldNode.accept(t)
		if err != nil {
			return "", err
		}
		fields = append(fields, field)
	}

	buf := &bytes.Buffer{}
	err = t.defTmpl.Execute(buf, defTmplData{
		DefName: name,
		Name:    t.Name,
		Fields:  fields,
	})
	if err != nil {
		return "", err
	}
	return buf.String(), nil
}

// visitField implements visitor.
func (t *transpiler) visitField(g *fieldNode) (string, error) {
	left, err := g.left.accept(t)
	if err != nil {
		return "", err
	}

	right, err := g.right.accept(t)
	if err != nil {
		return "", err
	}

	// A struct field is emitted as "<name> <type>", tab-separated.
	return fmt.Sprintf("\t%s\t%s", left, right), nil
}

// visitIdentifier implements visitor.
func (t *transpiler) visitIdentifier(i *identifierNode) (string, error) {
	return i.value, nil
}

// visitName implements visitor.
func (t *transpiler) visitName(n *nameNode) (string, error) {
	return n.value, nil
}

var _ visitor = new(transpiler)

// transpile renders the tree rooted at n into gofmt-formatted Go source.
// It returns the AST base name, the formatted source, and any error.
func transpile(n node) (string, []byte, error) {
	var err error

	t := new(transpiler)
	t.Definitions = []string{}
	t.DefinitionNames = []string{}

	// Check each parse error separately; otherwise the second assignment
	// silently overwrites the first error.
	t.astDefsTmpl, err = template.New("").Parse(astDefinitionsTemplate)
	if err != nil {
		return "", nil, err
	}

	t.defTmpl, err = template.New("").Parse(definitionTemplate)
	if err != nil {
		return "", nil, err
	}

	str, err := n.accept(t)
	if err != nil {
		return "", nil, err
	}

	bs, err := format.Source([]byte(str))
	return t.Name, bs, err
}