diff --git a/internal/raw_templates.go b/internal/raw_templates.go index 5bf0569..1c61c26 100644 --- a/internal/raw_templates.go +++ b/internal/raw_templates.go @@ -1 +1,39 @@ package internal + +const CreateRawTemplate = `func (service *{{.ServiceName}}) Create(item {{.EntityType}}) ({{.EntityType}}, error) { + err := dal.{{.EntityType}}.Preload(field.Associations).Create(&item) + return item, err +}` + +const GetAllRawTemplate = `func (service *{{.ServiceName}}) GetAll() ([]*{{.EntityType}}, error) { + var {{.EntityPlural}} []*{{.EntityType}} + {{.EntityPlural}}, err := dal.{{.EntityType}}.Preload(field.Associations).Find() + return {{.EntityPlural}}, err +}` + +const GetByIdRawTemplate = `func (service *{{.ServiceName}}) GetById(id uint) (*{{.EntityType}}, error) { + item, err := dal.{{.EntityType}}.Preload(field.Associations).Where(dal.{{.EntityType}}.Id.Eq(id)).First() + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } else { + return nil, err + } + } + return item, nil +}` + +const UpdateRawTemplate = `func (service *{{.ServiceName}}) Update(item {{.EntityType}}) ({{.EntityType}}, error) { + err := dal.{{.EntityType}}.Preload(field.Associations).Save(&item) + return item, err +}` + +const DeleteRawTemplate = `func (service *{{.ServiceName}}) Delete(item {{.EntityType}}) ({{.EntityType}}, error) { + _, err := dal.{{.EntityType}}.Unscoped().Preload(field.Associations).Delete(&item) + return item, err +}` + +const CountRawTemplate = `func (service *{{.ServiceName}}) Count() (int64, error) { + amount, err := dal.{{.EntityType}}.Count() + return amount, err +}` diff --git a/internal/templates.go b/internal/templates.go index 8246f41..d5b1395 100644 --- a/internal/templates.go +++ b/internal/templates.go @@ -1,30 +1,31 @@ package internal -import ( - "strings" -) - -type CrudTemplateContext struct { - ServiceName string - EntityType string +type CrudTemplatesContext struct { + ServiceName string + EntityType string EntityPlural string } var ServiceImports = []string{ "app/internal/dal", "app/internal/models", - //"errors" + "errors", "gorm.io/gen/field", - //"gorm.io/gorm" + "gorm.io/gorm", } -var GetAllRawTemplate = `func (service *{{.ServiceName}}) GetAll() ([]*{{.EntityType}}, error) { - var {{.EntityPlurar}} []*{{.EntityType}} - {{.EntityPlural}}, err := dal.{{.EntityType}}.Preload(field.Associations).Find() - return {{.EntityPlural}}, err -}` +const CreateMethod = "Create" +const GetAllMethod = "GetAll" +const GetByIdMethod = "GetById" +const UpdateMethod = "Update" +const DeleteMethod = "Delete" +const CountMethod = "Count" - -func ToPlural(entityName string) string { - return strings.ToLower(entityName) + "s" +var RawTemplates = map[string]string{ + CreateMethod: CreateRawTemplate, + GetAllMethod: GetAllRawTemplate, + GetByIdMethod: GetByIdRawTemplate, + UpdateMethod: UpdateRawTemplate, + DeleteMethod: DeleteRawTemplate, + CountMethod: CountRawTemplate, } diff --git a/internal/utils.go b/internal/utils.go index 583d742..b7fd1f9 100644 --- a/internal/utils.go +++ b/internal/utils.go @@ -1,6 +1,7 @@ package internal import ( + "strings" "unicode" ) @@ -12,3 +13,7 @@ func CapitalizeFirst(s string) string { runes[0] = unicode.ToUpper(runes[0]) return string(runes) } + +func ToPlural(entityName string) string { + return strings.ToLower(entityName) + "s" +} diff --git a/internal/writer.go b/internal/writer.go index 6e13dfd..691966f 100644 --- a/internal/writer.go +++ b/internal/writer.go @@ -1,12 +1,14 @@ package internal import ( + "bytes" "errors" "fmt" "go/ast" "go/parser" "go/printer" "go/token" + "html/template" "log" "os" "path/filepath" @@ -181,8 +183,105 @@ func MaintainImports(fileSet *token.FileSet, file *ast.File) error { return nil } -func ImplementMethods(structName string, methodName string, template string, node ast.Node) { +func GenerateCrudMethodCode(methodName string, context CrudTemplatesContext) string { + templateCode, templateExists := RawTemplates[methodName] + if !templateExists { + panic(fmt.Sprintf("Template doesn't exist: %s \n", methodName)) + } + var buffer bytes.Buffer + templateCode = "package services\n\n" + templateCode + tmpl, err := template.New("CrudMethodTemplate").Parse(templateCode) + if err != nil { + panic(err) + } + if err := tmpl.Execute(&buffer, context); err != nil { + panic(err) + } + return buffer.String() +} +func MethodCodeToDeclaration(methodCode string) (ast.FuncDecl, error) { + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "src.go", methodCode, parser.SkipObjectResolution) + if err != nil { + return ast.FuncDecl{}, err + } + + methodDecl := ast.FuncDecl{} + ast.Inspect(file, func(node ast.Node) bool { + funcDecl, ok := node.(*ast.FuncDecl) + if !ok { + return true + } + methodDecl = *funcDecl + return false + }) + return methodDecl, nil +} + +func ImplementCrudMethods(modelName string, serviceName string, file *ast.File, reimplement bool) error { + templateContext := CrudTemplatesContext{ + ServiceName: serviceName, + EntityType: modelName, + EntityPlural: ToPlural(modelName), + } + + for _, methodName := range []string{CreateMethod, GetAllMethod, GetByIdMethod, UpdateMethod, CountMethod} { + methodCode := GenerateCrudMethodCode(methodName, templateContext) + methodDecl, err := MethodCodeToDeclaration(methodCode) + if err != nil { + fmt.Println(methodDecl) + panic(err) + } + err = ImplementMethod(file, &methodDecl, reimplement) + if err != nil { + return err + } + } + + return nil +} + +func ImplementMethod(file *ast.File, methodDecl *ast.FuncDecl, reimplement bool) error { + var decls []ast.Decl + methodImplemented := false + + methodStructure := methodDecl.Recv.List[0].Names[0].Name + methodName := methodDecl.Name.Name + log.Printf("Standard method structure: %s\n", methodStructure) + log.Printf("Standard method name: %s\n", methodName) + + for _, decl := range file.Decls { + decls = append(decls, decl) + funcDecl, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + if len(funcDecl.Recv.List) > 0 && len(funcDecl.Recv.List[0].Names) > 0 { + fmt.Printf("Method structure: %s\n", funcDecl.Recv.List[0].Names[0].Name) + fmt.Printf("Method name: %s\n", funcDecl.Name.Name) + if funcDecl.Recv.List[0].Names[0].Name == methodStructure { + if funcDecl.Name != nil && funcDecl.Name.Name == methodName { + if methodImplemented { + err := fmt.Sprintf("`%s` method redeclarated for struct `%s`", methodName, methodStructure) + log.Println(err) + return errors.New(err) + } else { + methodImplemented = true + } + if reimplement { + decls = decls[:1] + } + } + } + } + } + + if reimplement || !methodImplemented { + file.Decls = append(decls, methodDecl) + } + + return nil } func CreateServiceFileIfNotExists(filePath string) error { @@ -208,6 +307,7 @@ func CreateServiceFileIfNotExists(filePath string) error { func ImplementService(mainPkgPath string, modelName string, reimplement bool) error { serviceRelativePath := fmt.Sprintf("services/%s.go", strings.ToLower(modelName)) filePath := filepath.Join(mainPkgPath, serviceRelativePath) + serviceName := modelName + "Service" err := CreateServiceFileIfNotExists(filePath) if err != nil { @@ -225,19 +325,25 @@ func ImplementService(mainPkgPath string, modelName string, reimplement bool) er if err != nil { return err } + ImplementModelAlias(modelName, serviceFile) ImplementServiceStruct(modelName, serviceFile, reimplement) + err = ImplementCrudMethods(modelName, serviceName, serviceFile, reimplement) + + if err != nil { + return err + } file, err := os.Create(filePath) if err != nil { - log.Fatalf("Error occured to open `%s` service file: %v", modelName, err) - return err + return errors.New(fmt.Sprintf("Error occured to open `%s` service file: %v", modelName, err)) } err = printer.Fprint(file, fset, serviceFile) if err != nil { - log.Fatalf("Error occurred to writing changes in `%s` service file: %v", modelName, err) - return err + return errors.New( + fmt.Sprintf("Error occurred to writing changes in `%s` service file: %v", modelName, err), + ) } defer file.Close()