diff --git a/grpcv1.go b/grpcv1.go index 1c891b94..b1e5c1ee 100644 --- a/grpcv1.go +++ b/grpcv1.go @@ -3,6 +3,7 @@ package headscale import ( "context" + "fmt" "strings" "time" @@ -195,13 +196,11 @@ func (api headscaleV1APIServer) SetTags( } for _, tag := range request.GetTags() { - if strings.Index(tag, "tag:") != 0 { + err := validateTag(tag) + if err != nil { return &v1.SetTagsResponse{ Machine: nil, - }, status.Error( - codes.InvalidArgument, - "Invalid tag detected. Each tag must start with the string 'tag:'", - ) + }, status.Error(codes.InvalidArgument, err.Error()) } } @@ -220,6 +219,19 @@ func (api headscaleV1APIServer) SetTags( return &v1.SetTagsResponse{Machine: machine.toProto()}, nil } +func validateTag(tag string) error { + if strings.Index(tag, "tag:") != 0 { + return fmt.Errorf("tag must start with the string 'tag:'") + } + if strings.ToLower(tag) != tag { + return fmt.Errorf("tag should be lowercase") + } + if len(strings.Fields(tag)) > 1 { + return fmt.Errorf("tag should not contains space") + } + return nil +} + func (api headscaleV1APIServer) DeleteMachine( ctx context.Context, request *v1.DeleteMachineRequest, diff --git a/grpcv1_test.go b/grpcv1_test.go new file mode 100644 index 00000000..e48ae1ef --- /dev/null +++ b/grpcv1_test.go @@ -0,0 +1,42 @@ +package headscale + +import "testing" + +func Test_validateTag(t *testing.T) { + type args struct { + tag string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "valid tag", + args: args{tag: "tag:test"}, + wantErr: false, + }, + { + name: "tag without tag prefix", + args: args{tag: "test"}, + wantErr: true, + }, + { + name: "uppercase tag", + args: args{tag: "tag:tEST"}, + wantErr: true, + }, + { + name: "tag that contains space", + args: args{tag: "tag:this is a spaced tag"}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := validateTag(tt.args.tag); (err != nil) != tt.wantErr { + t.Errorf("validateTag() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/integration_cli_test.go b/integration_cli_test.go index f9ff5ec0..2f58e71d 100644 --- a/integration_cli_test.go +++ b/integration_cli_test.go @@ -625,7 +625,7 @@ func (s *IntegrationCLITestSuite) TestNodeTagCommand() { var errorOutput errOutput err = json.Unmarshal([]byte(wrongTagResult), &errorOutput) assert.Nil(s.T(), err) - assert.Contains(s.T(), errorOutput.Error, "Invalid tag detected") + assert.Contains(s.T(), errorOutput.Error, "tag must start with the string 'tag:'") // Test list all nodes after added seconds listAllResult, err := ExecuteCommand(