diff --git a/.circleci/config.yml b/.circleci/config.yml index 0156cdb..e73fa69 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,8 +1,27 @@ version: 2.1 jobs: + test: + docker: + - image: cimg/go:1.22 + steps: + - checkout + - restore_cache: + keys: + - go-mod-v1-{{ checksum "go.sum" }} + - go-mod-v1 + - run: + name: Install Dependencies + command: go get ./... + - save_cache: + key: go-mod-v1-{{ checksum "go.sum" }} + paths: + - "/go/pkg/mod" + - run: + name: Run tests + command: go test ./... build: - docker: + docker: - image: cimg/go:1.22 steps: - run: @@ -20,7 +39,7 @@ jobs: key: go-mod-v1-{{ checksum "go.sum" }} paths: - "/go/pkg/mod" - - run: + - run: name: Build command: go build -C cmd/gowake/ -v -ldflags="-w -s" - run: @@ -28,19 +47,19 @@ jobs: command: upx ./cmd/gowake/gowake - persist_to_workspace: root: ~/project - paths: - - cmd/gowake/gowake + paths: + - cmd/gowake/gowake release: - docker: - - image: cimg/base:current + docker: + - image: cimg/base:current steps: - attach_workspace: at: ./ - - run: + - run: name: Package command: tar cvfz ${CIRCLE_PROJECT_REPONAME}_${CIRCLE_TAG}_linux_amd64.tar.gz cmd/${CIRCLE_PROJECT_REPONAME}/${CIRCLE_PROJECT_REPONAME} - run: - name: Release + name: Release command: | curl -v \ -X POST \ @@ -56,7 +75,7 @@ jobs: -H "X-GitHub-Api-Version: 2022-11-28" \ https://api.github.com/repos/jedrw/${CIRCLE_PROJECT_REPONAME}/releases/tags/${CIRCLE_TAG} \ | jq '.id') - + curl -v \ -X POST \ -H "Accept: application/vnd.github+json" \ @@ -67,8 +86,12 @@ jobs: --data-binary @${CIRCLE_PROJECT_REPONAME}_${CIRCLE_TAG}_linux_amd64.tar.gz workflows: - build_and_release: + test_build_and_release: jobs: + - test: + filters: + tags: + only: /v\d+\.\d+\.\d+/ - build: filters: tags: @@ -77,10 +100,10 @@ workflows: context: - github requires: + - test - build filters: branches: ignore: /.*/ tags: only: /v\d+\.\d+\.\d+/ - diff --git a/cmd/gowake/gowake.go b/cmd/gowake/gowake.go index fabdb2e..879da3d 100644 --- a/cmd/gowake/gowake.go +++ b/cmd/gowake/gowake.go @@ -1,59 +1,48 @@ package main import ( + "errors" "fmt" - "regexp" + "net" - "github.com/jedrw/gowake/cmd/listen" + "github.com/jedrw/gowake/cmd/gowake/listen" "github.com/jedrw/gowake/pkg/magicpacket" "github.com/spf13/cobra" ) -var gowakeCmd = &cobra.Command{ - Use: "gowake [macaddress]", - Short: "Send a magic packet in order wake a compliant machine", - Args: cobra.ExactArgs(1), - Run: func(cmd *cobra.Command, args []string) { - // Get port - port, _ := cmd.Flags().GetInt("port") +var port int +var ip string - // Get IP +var gowakeCmd = &cobra.Command{ + Use: "gowake [macaddress]", + Short: "Send a magic packet", + Args: cobra.ExactArgs(1), + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { ip, _ := cmd.Flags().GetString("ip") - - is_ip, _ := regexp.MatchString(`^((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b){4}$`, ip) - if !is_ip { - fmt.Println("Please provide a valid IP address") - return + if net.ParseIP(ip) == nil { + return errors.New("got invalid IP") } - // Check for mac address - if len(args) < 1 { - fmt.Println("Please provide a MAC address") - return - } - - // Build packet - mp, err := magicpacket.New(args[0]) + port, _ := cmd.Flags().GetInt("port") + magicPacket, err := magicpacket.New(args[0]) if err != nil { - fmt.Println(err.Error()) - return + return err } - // Send packet - err = magicpacket.Send(mp, ip, port) + err = magicpacket.Send(magicPacket, ip, port) if err != nil { - fmt.Println(err.Error()) + return err } - fmt.Printf("Sent magic packet to %v\n", args[0]) + fmt.Printf("Sent magic packet %s to %s:%d\n", args[0], ip, port) + return nil }, } func init() { - var port int - var ip string gowakeCmd.AddCommand(listen.ListenCmd) - gowakeCmd.PersistentFlags().IntVarP(&port, "port", "p", 9, "Port to send or listen for magic packet") + gowakeCmd.Flags().IntVarP(&port, "port", "p", 9, "Port to send magic packet to") gowakeCmd.Flags().StringVarP(&ip, "ip", "i", "255.255.255.255", "Destination (IP or broadcast address) for the magic packet") gowakeCmd.PersistentFlags().BoolP("help", "h", false, "Print help for command") cobra.EnableCommandSorting = false diff --git a/cmd/gowake/listen/listen.go b/cmd/gowake/listen/listen.go new file mode 100644 index 0000000..6307ffe --- /dev/null +++ b/cmd/gowake/listen/listen.go @@ -0,0 +1,51 @@ +package listen + +import ( + "errors" + "fmt" + "syscall" + + "github.com/jedrw/gowake/pkg/magicpacket" + "github.com/spf13/cobra" +) + +var port int +var ip string + +var ListenCmd = &cobra.Command{ + Use: "listen", + Short: "Listen for a magic packet", + RunE: func(cmd *cobra.Command, args []string) error { + ip, _ := cmd.Flags().GetString("ip") + port, _ := cmd.Flags().GetInt("port") + cont, _ := cmd.Flags().GetBool("continuous") + fmt.Printf("Listening for magic packets on %s:%d\n", ip, port) + for { + remote, mac, err := magicpacket.Listen(ip, port) + if err != nil { + var errno syscall.Errno + if errors.As(err, &errno) { + if errno == syscall.EACCES { + return fmt.Errorf("%w: please run as elevated user", err) + } + } else { + return err + } + } + + fmt.Printf("%s from %s\n", mac, remote.String()) + if !cont { + break + } + } + + return nil + }, +} + +func init() { + var continuous bool + ListenCmd.Flags().IntVarP(&port, "port", "p", 9, "Port to listen for magic packets on") + ListenCmd.Flags().StringVarP(&ip, "ip", "i", "0.0.0.0", "Address to listen for magic packets on") + ListenCmd.Flags().BoolVarP(&continuous, "continuous", "c", false, "Listen continuously for magic packets") +} diff --git a/cmd/listen/listen.go b/cmd/listen/listen.go deleted file mode 100644 index c078eb1..0000000 --- a/cmd/listen/listen.go +++ /dev/null @@ -1,39 +0,0 @@ -package listen - -import ( - "fmt" - - "github.com/jedrw/gowake/pkg/magicpacket" - "github.com/spf13/cobra" -) - -var ListenCmd = &cobra.Command{ - Use: "listen", - Short: "Listen for a magic packet", - Run: func(cmd *cobra.Command, args []string) { - port, _ := cmd.Flags().GetInt("port") - cont, _ := cmd.Flags().GetBool("continuous") - fmt.Printf("Listening for magic packets on port %d:\n", port) - for { - remote, mac, err := magicpacket.Listen(port) - if err != nil { - if err.Error() == fmt.Sprintf("listen udp 0.0.0.0:%d: bind: permission denied", port) { - fmt.Println("Please run as elevated user") - return - } else { - fmt.Println(err) - return - } - } - fmt.Printf("%v from %v\n", mac, remote.String()) - if !cont { - break - } - } - }, -} - -func init() { - var continuous bool - ListenCmd.Flags().BoolVarP(&continuous, "continuous", "c", false, "Listen continuously for magic packets") -} diff --git a/pkg/magicpacket/listen.go b/pkg/magicpacket/listen.go index 4c0918a..af2be21 100644 --- a/pkg/magicpacket/listen.go +++ b/pkg/magicpacket/listen.go @@ -1,14 +1,18 @@ package magicpacket import ( - "bytes" "fmt" "net" ) -func Listen(port int) (*net.UDPAddr, string, error) { +func Listen(ip string, port int) (*net.UDPAddr, string, error) { + listenIP := net.ParseIP(ip) + if listenIP == nil { + return nil, "", fmt.Errorf("invalid IP: %s", ip) + } + addr := net.UDPAddr{ - IP: net.ParseIP("0.0.0.0"), + IP: listenIP, Port: port, } @@ -18,21 +22,11 @@ func Listen(port int) (*net.UDPAddr, string, error) { } defer listener.Close() - var magicPacket MagicPacket - remote := &net.UDPAddr{} - _, remote, err = listener.ReadFromUDP(magicPacket[:]) + magicPacket := MagicPacket{} + _, remote, err := listener.ReadFromUDP(magicPacket[:]) if err != nil { return remote, "", err } - macLength := 6 - offset := 6 - for i := 0; i < 16; i++ { - if !bytes.Equal(magicPacket[offset:offset+macLength], magicPacket[96:]) { - return remote, "", fmt.Errorf("received malformed magicpacket from %v", remote) - } - offset += 6 - } - - return remote, net.HardwareAddr.String(magicPacket[96:]), err + return remote, magicPacket.Mac(), magicPacket.Validate() } diff --git a/pkg/magicpacket/magicpacket.go b/pkg/magicpacket/magicpacket.go index bbf6f8d..b9afafe 100644 --- a/pkg/magicpacket/magicpacket.go +++ b/pkg/magicpacket/magicpacket.go @@ -1,31 +1,62 @@ package magicpacket import ( + "bytes" + "errors" "fmt" "net" ) +var ( + ErrNotValidEUI48MacAddress = errors.New("not a valid EUI-48 MAC address") + ErrMalformedMagicPacket = errors.New("malformed magic packet") +) + type MagicPacket [102]byte func New(mac string) (MagicPacket, error) { - // Parse mac address hwAddr, err := net.ParseMAC(mac) if err != nil { return MagicPacket{}, err } if len(hwAddr) != 6 { - return MagicPacket{}, fmt.Errorf("invalid EUI-48 MAC address") + return MagicPacket{}, fmt.Errorf("%w, %s", ErrNotValidEUI48MacAddress, hwAddr) } - // Build magicpacket magicPacket := MagicPacket{255, 255, 255, 255, 255, 255} - offset := 6 for i := 0; i < 16; i++ { copy(magicPacket[offset:], hwAddr[:]) offset += 6 } - return magicPacket, err + return magicPacket, nil +} + +func (mp MagicPacket) Bytes() []byte { + return mp[:] +} + +func (mp MagicPacket) Mac() string { + return net.HardwareAddr.String(mp[96:]) +} + +func (mp MagicPacket) Validate() error { + macLength := 6 + offset := 6 + for i := 0; i < 16; i++ { + if !bytes.Equal(mp[offset:offset+macLength], mp[96:]) { + return ErrMalformedMagicPacket + } + + offset += 6 + } + + _, err := net.ParseMAC(mp.Mac()) + if err != nil { + return err + } + + return nil } diff --git a/pkg/magicpacket/magicpacket_test.go b/pkg/magicpacket/magicpacket_test.go new file mode 100644 index 0000000..a46a322 --- /dev/null +++ b/pkg/magicpacket/magicpacket_test.go @@ -0,0 +1,63 @@ +package magicpacket_test + +import ( + "bytes" + "errors" + "testing" + + "github.com/jedrw/gowake/pkg/magicpacket" +) + +func TestNewErrorsWithInvalidMacAddress(t *testing.T) { + _, err := magicpacket.New("ab:ab:ab:ab:ab:ag") + if err == nil { + t.Errorf("expected error, got nil") + } +} + +func TestNewErrorsWithNonEUI48MacAddress(t *testing.T) { + _, err := magicpacket.New("02:00:5e:10:00:00:00:01") + if !errors.Is(err, magicpacket.ErrNotValidEUI48MacAddress) { + t.Errorf("expected %s, Got: %s", magicpacket.ErrNotValidEUI48MacAddress, err) + } +} + +func TestNewReturnsValidMagicPacket(t *testing.T) { + magicPacket, err := magicpacket.New("ab:ab:ab:ab:ab:ab") + if err != nil { + t.Error(err) + } + + expectedBytes := []byte{255, 255, 255, 255, 255, 255, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171} + if !bytes.Equal(magicPacket.Bytes(), expectedBytes) { + t.Errorf("expected: %b, got: %b", expectedBytes, magicPacket.Bytes()) + } +} + +func TestMagicPacketMac(t *testing.T) { + magicPacket, err := magicpacket.New("ab:ab:ab:ab:ab:ab") + if err != nil { + t.Error(err) + } + + expectedMac := "ab:ab:ab:ab:ab:ab" + if magicPacket.Mac() != expectedMac { + t.Errorf("expected: %s, got: %s", expectedMac, magicPacket.Mac()) + } +} + +func TestMagicPacketValidateReturnsErrorOnMalformedMagicPacket(t *testing.T) { + magicPacket := magicpacket.MagicPacket{255, 255, 255, 255, 255, 255} + offset := 6 + for i := 0; i < 15; i++ { + copy(magicPacket[offset:], []byte{171, 171, 171, 171, 171, 171}) + offset += 6 + } + + copy(magicPacket[offset:], []byte{172, 172, 172, 172, 172, 172}) + + err := magicPacket.Validate() + if !errors.Is(err, magicpacket.ErrMalformedMagicPacket) { + t.Errorf("expected %s, Got: %s", magicpacket.ErrMalformedMagicPacket, err) + } +} diff --git a/pkg/magicpacket/send.go b/pkg/magicpacket/send.go index 7749685..6f47909 100644 --- a/pkg/magicpacket/send.go +++ b/pkg/magicpacket/send.go @@ -5,14 +5,24 @@ import ( "net" ) -func Send(packet MagicPacket, ip string, port int) error { - conn, err := net.Dial("udp", fmt.Sprintf("%s:%d", ip, port)) +func Send(magicPacket MagicPacket, ip string, port int) error { + sendIP := net.ParseIP(ip) + if sendIP == nil { + return fmt.Errorf("invalid IP: %s", ip) + } + + addr := net.UDPAddr{ + IP: sendIP, + Port: port, + } + + conn, err := net.Dial("udp", addr.String()) if err != nil { return err } defer conn.Close() - _, err = conn.Write(packet[:]) + _, err = conn.Write(magicPacket.Bytes()) if err != nil { return err }