package iputil import ( "encoding/binary" "golang.org/x/net/ipv4" ) //TODO: IPV6-WORK can probably delete this const ( // Need 96 bytes for the largest reject packet: // - 20 byte ipv4 header // - 8 byte icmpv4 header // - 68 byte body (60 byte max orig ipv4 header + 8 byte orig icmpv4 header) MaxRejectPacketSize = ipv4.HeaderLen + 8 + 60 + 8 ) func CreateRejectPacket(packet []byte, out []byte) []byte { if len(packet) < ipv4.HeaderLen || int(packet[0]>>4) != ipv4.Version { return nil } switch packet[9] { case 6: // tcp return ipv4CreateRejectTCPPacket(packet, out) default: return ipv4CreateRejectICMPPacket(packet, out) } } func ipv4CreateRejectICMPPacket(packet []byte, out []byte) []byte { ihl := int(packet[0]&0x0f) << 2 if len(packet) < ihl { // We need at least this many bytes for this to be a valid packet return nil } // ICMP reply includes original header and first 8 bytes of the packet packetLen := len(packet) if packetLen > ihl+8 { packetLen = ihl + 8 } outLen := ipv4.HeaderLen + 8 + packetLen if outLen > cap(out) { return nil } out = out[:outLen] ipHdr := out[0:ipv4.HeaderLen] ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2) // version, ihl ipHdr[1] = 0 // DSCP, ECN binary.BigEndian.PutUint16(ipHdr[2:], uint16(outLen)) // Total Length ipHdr[4] = 0 // id ipHdr[5] = 0 // . ipHdr[6] = 0 // flags, fragment offset ipHdr[7] = 0 // . ipHdr[8] = 64 // TTL ipHdr[9] = 1 // protocol (icmp) ipHdr[10] = 0 // checksum ipHdr[11] = 0 // . // Swap dest / src IPs copy(ipHdr[12:16], packet[16:20]) copy(ipHdr[16:20], packet[12:16]) // Calculate checksum binary.BigEndian.PutUint16(ipHdr[10:], tcpipChecksum(ipHdr, 0)) // ICMP Destination Unreachable icmpOut := out[ipv4.HeaderLen:] icmpOut[0] = 3 // type (Destination unreachable) icmpOut[1] = 3 // code (Port unreachable error) icmpOut[2] = 0 // checksum icmpOut[3] = 0 // . icmpOut[4] = 0 // unused icmpOut[5] = 0 // . icmpOut[6] = 0 // . icmpOut[7] = 0 // . // Copy original IP header and first 8 bytes as body copy(icmpOut[8:], packet[:packetLen]) // Calculate checksum binary.BigEndian.PutUint16(icmpOut[2:], tcpipChecksum(icmpOut, 0)) return out } func ipv4CreateRejectTCPPacket(packet []byte, out []byte) []byte { const tcpLen = 20 ihl := int(packet[0]&0x0f) << 2 outLen := ipv4.HeaderLen + tcpLen if len(packet) < ihl+tcpLen { // We need at least this many bytes for this to be a valid packet return nil } if outLen > cap(out) { return nil } out = out[:outLen] ipHdr := out[0:ipv4.HeaderLen] ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2) // version, ihl ipHdr[1] = 0 // DSCP, ECN binary.BigEndian.PutUint16(ipHdr[2:], uint16(outLen)) // Total Length ipHdr[4] = 0 // id ipHdr[5] = 0 // . ipHdr[6] = 0 // flags, fragment offset ipHdr[7] = 0 // . ipHdr[8] = 64 // TTL ipHdr[9] = 6 // protocol (tcp) ipHdr[10] = 0 // checksum ipHdr[11] = 0 // . // Swap dest / src IPs copy(ipHdr[12:16], packet[16:20]) copy(ipHdr[16:20], packet[12:16]) // Calculate checksum binary.BigEndian.PutUint16(ipHdr[10:], tcpipChecksum(ipHdr, 0)) // TCP RST tcpIn := packet[ihl:] var ackSeq, seq uint32 outFlags := byte(0b00000100) // RST // Set seq and ackSeq based on how iptables/netfilter does it in Linux: // - https://github.com/torvalds/linux/blob/v5.19/net/ipv4/netfilter/nf_reject_ipv4.c#L193-L221 inAck := tcpIn[13]&0b00010000 != 0 if inAck { seq = binary.BigEndian.Uint32(tcpIn[8:]) } else { inSyn := uint32((tcpIn[13] & 0b00000010) >> 1) inFin := uint32(tcpIn[13] & 0b00000001) // seq from the packet + syn + fin + tcp segment length ackSeq = binary.BigEndian.Uint32(tcpIn[4:]) + inSyn + inFin + uint32(len(tcpIn)) - uint32(tcpIn[12]>>4)<<2 outFlags |= 0b00010000 // ACK } tcpOut := out[ipv4.HeaderLen:] // Swap dest / src ports copy(tcpOut[0:2], tcpIn[2:4]) copy(tcpOut[2:4], tcpIn[0:2]) binary.BigEndian.PutUint32(tcpOut[4:], seq) binary.BigEndian.PutUint32(tcpOut[8:], ackSeq) tcpOut[12] = (tcpLen >> 2) << 4 // data offset, reserved, NS tcpOut[13] = outFlags // CWR, ECE, URG, ACK, PSH, RST, SYN, FIN tcpOut[14] = 0 // window size tcpOut[15] = 0 // . tcpOut[16] = 0 // checksum tcpOut[17] = 0 // . tcpOut[18] = 0 // URG Pointer tcpOut[19] = 0 // . // Calculate checksum csum := ipv4PseudoheaderChecksum(ipHdr[12:16], ipHdr[16:20], 6, tcpLen) binary.BigEndian.PutUint16(tcpOut[16:], tcpipChecksum(tcpOut, csum)) return out } func CreateICMPEchoResponse(packet, out []byte) []byte { // Return early if this is not a simple ICMP Echo Request //TODO: make constants out of these if !(len(packet) >= 28 && len(packet) <= 9001 && packet[0] == 0x45 && packet[9] == 0x01 && packet[20] == 0x08) { return nil } // We don't support fragmented packets if packet[7] != 0 || (packet[6]&0x2F != 0) { return nil } out = out[:len(packet)] copy(out, packet) // Swap dest / src IPs and recalculate checksum ipv4 := out[0:20] copy(ipv4[12:16], packet[16:20]) copy(ipv4[16:20], packet[12:16]) ipv4[10] = 0 ipv4[11] = 0 binary.BigEndian.PutUint16(ipv4[10:], tcpipChecksum(ipv4, 0)) // Change type to ICMP Echo Reply and recalculate checksum icmp := out[20:] icmp[0] = 0 icmp[2] = 0 icmp[3] = 0 binary.BigEndian.PutUint16(icmp[2:], tcpipChecksum(icmp, 0)) return out } // calculates the TCP/IP checksum defined in rfc1071. The passed-in // csum is any initial checksum data that's already been computed. // // based on: // - https://github.com/google/gopacket/blob/v1.1.19/layers/tcpip.go#L50-L70 func tcpipChecksum(data []byte, csum uint32) uint16 { // to handle odd lengths, we loop to length - 1, incrementing by 2, then // handle the last byte specifically by checking against the original // length. length := len(data) - 1 for i := 0; i < length; i += 2 { // For our test packet, doing this manually is about 25% faster // (740 ns vs. 1000ns) than doing it by calling binary.BigEndian.Uint16. csum += uint32(data[i]) << 8 csum += uint32(data[i+1]) } if len(data)%2 == 1 { csum += uint32(data[length]) << 8 } for csum > 0xffff { csum = (csum >> 16) + (csum & 0xffff) } return ^uint16(csum) } // based on: // - https://github.com/google/gopacket/blob/v1.1.19/layers/tcpip.go#L26-L35 func ipv4PseudoheaderChecksum(src, dst []byte, proto, length uint32) (csum uint32) { csum += (uint32(src[0]) + uint32(src[2])) << 8 csum += uint32(src[1]) + uint32(src[3]) csum += (uint32(dst[0]) + uint32(dst[2])) << 8 csum += uint32(dst[1]) + uint32(dst[3]) csum += proto csum += length & 0xffff csum += length >> 16 return csum }