首页 > 代码库 > tcpsock for Golang

tcpsock for Golang

  前记:本文所述的 tcpsock 库托管在 GitHub。

 

  Golang 中的 net 标准库已对 TCP 网络编程作了简洁(却很不简单)的封装,基本上,可直接通过引用其提供的相关接口开发简易的网络应用程序。但我想封装一套简单并提供简洁调用接口的 TCP 网络库(即 tcpsock),以达到所谓组件化的目的,如此用户只需调用几个非常简单的接口(包括设置回调函数等)就能较直观地处理网络收发等逻辑。

  一开始设计 tcpsock 时,只打算提供 TcpServer 及 TcpConn 对象而没有考虑加入 TcpClient(毕竟 Golang 多用于后台开发),但譬如网络游戏中的网关(GameGate),其既是服务器程序(服务于为数众多的网络客户端),又要充当客户端连接到游戏服务器(GameServer),此时即凸显 TcpClient 存在的必要了。所以重构了部分结构与代码,加入了 TcpClient 对象,而这也带来了额外的好处,具体可参考 chatroom 示例,其服务器与客户端在诸如协议解析等方面完全复用代码,大大减少了开发维护成本。

  囿于个人当前的能力及精力, tcpsock 不可避免会存在一些甚至许多问题或 Bug(虽然它确实很简单),希望能在实际的项目运用及与技术同仁的探讨中一点点改进它,直至完善。

  tcpsock 库共 4 个单元(tcpsock、tcpserver、tcpconn 及 tcpclient),并提供 3 个导出类型(TcpServer、TcpClient 及 TcpConn,分别实现于与其名称对应的单元),各单元代码如下:

// Copyright (C) 2017 ecofast(胡光耀). All rights reserved.
// Use of this source code is governed by a BSD-style license.

// Package tcpsock provides easy to use interfaces for TCP I/O.
// Thanks to darksword(gansidui) and AlexStocks for their valuable projects
// which are gotcp(https://github.com/gansidui/gotcp)
// and getty(https://github.com/AlexStocks/getty).
package tcpsock

import (
	"sync"
)

// Sizing limits applied to every connection created by a TcpServer or TcpClient.
const (
	// RecvBufLenMax is the size, in bytes, of the scratch buffer each
	// connection reads the socket into; protocols may also use it as a limit.
	RecvBufLenMax = 1 << 12
	// RecvBufCapMax is the capacity of each connection's inbound packet channel.
	RecvBufCapMax = 10

	// SendBufLenMax is the send-side counterpart of RecvBufLenMax.
	// NOTE(review): not referenced by the visible sources.
	SendBufLenMax = 1 << 12
	// SendBufCapMax is the capacity of each connection's outbound packet channel.
	SendBufCapMax = 10
)

// tcpSock is the state a TcpServer or TcpClient shares with every TcpConn it
// creates: channel capacities, the Protocol, the shutdown signal, the
// WaitGroup tracking all goroutines, and the connect/close callbacks.
type tcpSock struct {
	sendBufCap    uint32            // capacity of each connection's outbound packet channel
	recvBufCap    uint32            // capacity of each connection's inbound packet channel
	proto         Protocol          // one value used by all connections of this owner
	exitChan      chan struct{}     // closed by the owner's Close to stop every goroutine
	waitGroup     *sync.WaitGroup   // tracks Serve/Run plus each connection's goroutines
	onConnConnect OnTcpConnCallback // called for each new connection before its goroutines start
	onConnClose   OnTcpConnCallback // called (via the owner's connClose) when a connection closes
}

// Protocol turns a connection's raw byte stream into Packets and acts on them.
type Protocol interface {
	// Parse is called by a connection's reader goroutine with each chunk read
	// from the socket. It should send every complete packet it can decode to
	// recvChan and keep any incomplete tail for the next call.
	// NOTE(review): the owner stores a single Protocol and every connection
	// calls Parse on it, possibly concurrently, so buffering state must not be
	// shared across connections.
	Parse(b []byte, recvChan chan<- Packet)
	// Process is called by the connection's handler goroutine, once per
	// packet received from recvChan, in arrival order.
	Process(conn *TcpConn, p Packet)
}

// Packet is one unit of application data. Marshal returns its wire encoding,
// which a connection's writer goroutine writes to the socket unchanged.
type Packet interface {
	Marshal() []byte
}
// Copyright (C) 2017 ecofast(胡光耀). All rights reserved.
// Use of this source code is governed by a BSD-style license.

package tcpsock

import (
	"net"
	"sync"
	"sync/atomic"
	"time"

	. "github.com/ecofast/sysutils"
)

// TcpServer accepts TCP connections on one port and runs each of them through
// the shared Protocol. Create it with NewTcpServer, run Serve, stop with Close.
type TcpServer struct {
	listener      *net.TCPListener
	acceptTimeout int // seconds Accept may block before Serve re-checks exitChan
	*tcpSock
	autoIncID uint32 // last connection ID handed out (updated atomically)
	numOfConn uint32 // number of accepted, not yet closed connections (atomic)
}

// NewTcpServer creates a server listening on listenPort on all interfaces.
// acceptTimeout is how many seconds each Accept call in Serve may block before
// the shutdown signal is checked again. Address resolution and listen errors
// are handed to CheckError (sysutils). NOTE(review): presumably that aborts.
func NewTcpServer(listenPort, acceptTimeout int, protocol Protocol) *TcpServer {
	addr, err := net.ResolveTCPAddr("tcp", ":"+IntToStr(listenPort))
	CheckError(err)

	ln, err := net.ListenTCP("tcp", addr)
	CheckError(err)

	shared := &tcpSock{
		sendBufCap: SendBufCapMax,
		recvBufCap: RecvBufCapMax,
		proto:      protocol,
		exitChan:   make(chan struct{}),
		waitGroup:  new(sync.WaitGroup),
	}

	srv := new(TcpServer)
	srv.listener = ln
	srv.acceptTimeout = acceptTimeout
	srv.tcpSock = shared
	return srv
}

// Serve runs the accept loop until Close is called. Each accepted connection
// gets the next auto-incremented ID. It is announced through the
// OnConnConnect callback, and then its reader/writer/handler goroutines are
// started. The listener is closed when Serve returns.
func (s *TcpServer) Serve() {
	s.waitGroup.Add(1)
	defer func() {
		s.listener.Close()
		s.waitGroup.Done()
	}()

	// Back-off applied after Accept errors that are not deadline timeouts.
	var retryDelay time.Duration
	for {
		select {
		case <-s.exitChan:
			return
		default:
		}

		// The deadline bounds how long Accept blocks, so the exitChan check
		// above runs at least once every acceptTimeout seconds.
		s.listener.SetDeadline(time.Now().Add(time.Duration(s.acceptTimeout) * time.Second))
		conn, err := s.listener.AcceptTCP()
		if err != nil {
			if ne, ok := err.(net.Error); ok && ne.Timeout() {
				retryDelay = 0
				continue
			}
			// Other errors (e.g. out of file descriptors) usually repeat
			// immediately; wait instead of spinning, but stay responsive
			// to shutdown.
			retryDelay *= 2
			if retryDelay == 0 {
				retryDelay = 5 * time.Millisecond
			}
			if retryDelay > time.Second {
				retryDelay = time.Second
			}
			select {
			case <-s.exitChan:
				return
			case <-time.After(retryDelay):
			}
			continue
		}
		retryDelay = 0

		atomic.AddUint32(&s.numOfConn, 1)
		s.waitGroup.Add(1)
		go func() {
			defer s.waitGroup.Done()
			c := newTcpConn(atomic.AddUint32(&s.autoIncID, 1), s.tcpSock, conn, s.sendBufCap, s.recvBufCap, s.connClose)
			if s.onConnConnect != nil {
				s.onConnConnect(c)
			}
			c.run()
		}()
	}
}

// Close signals Serve and every connection goroutine to stop. It then blocks
// until all of them have exited. It closes exitChan, so call it only once.
func (s *TcpServer) Close() {
	defer s.waitGroup.Wait()
	close(s.exitChan)
}

// NumOfConn reports how many accepted connections have not closed yet.
func (s *TcpServer) NumOfConn() uint32 { return atomic.LoadUint32(&s.numOfConn) }

// connClose is the close hook given to every accepted TcpConn. It decrements
// the live-connection counter and then forwards to the user's OnConnClose
// callback, if one is set.
func (s *TcpServer) connClose(conn *TcpConn) {
	atomic.AddUint32(&s.numOfConn, ^uint32(0)) // adding all-ones subtracts one
	if cb := s.onConnClose; cb != nil {
		cb(conn)
	}
}

// OnConnConnect registers fn to run for each accepted connection before its
// goroutines start. Set it before calling Serve, because connection
// goroutines read the field without synchronization.
func (s *TcpServer) OnConnConnect(fn OnTcpConnCallback) {
	s.tcpSock.onConnConnect = fn
}

// OnConnClose registers fn to run when an accepted connection closes, after
// its socket has been closed. Set it before calling Serve.
func (s *TcpServer) OnConnClose(fn OnTcpConnCallback) {
	s.tcpSock.onConnClose = fn
}
// Copyright (C) 2017 ecofast(胡光耀). All rights reserved.
// Use of this source code is governed by a BSD-style license.

package tcpsock

import (
	"net"
	"sync"
	"sync/atomic"
)

// OnTcpConnCallback is the signature of the connect and close notifications
// delivered for a TcpConn.
type OnTcpConnCallback func(c *TcpConn)

// TcpConn is one TCP connection managed by a TcpServer or TcpClient. It is
// served by three goroutines:
//   - reader: socket → Protocol.Parse → recvChan
//   - handler: recvChan → Protocol.Process
//   - writer: sendChan → socket
type TcpConn struct {
	id         uint32            // server-assigned sequence number; always 0 for client connections
	owner      *tcpSock          // shared settings, protocol and shutdown signal
	conn       *net.TCPConn      // underlying socket
	sendChan   chan Packet       // outbound packets queued by Write
	recvChan   chan Packet       // inbound packets produced by Parse
	closeChan  chan struct{}     // closed by Close to stop this connection's goroutines
	closeOnce  sync.Once         // makes Close idempotent
	closedFlag int32             // set to 1 when Close runs (accessed atomically)
	onClose    OnTcpConnCallback // owner's close hook, invoked from Close
}

// newTcpConn wraps conn with send/receive packet queues of the given
// capacities. The connection does nothing until run is called.
func newTcpConn(id uint32, owner *tcpSock, conn *net.TCPConn, sendCap, recvCap uint32, onClose OnTcpConnCallback) *TcpConn {
	c := &TcpConn{id: id, owner: owner, conn: conn, onClose: onClose}
	c.sendChan = make(chan Packet, sendCap)
	c.recvChan = make(chan Packet, recvCap)
	c.closeChan = make(chan struct{})
	return c
}

// ID returns the identifier assigned at creation: a server-side sequence
// number, or 0 for a client's connection.
func (c *TcpConn) ID() uint32 { return c.id }

// run starts the reader, writer and handler goroutines, in that order, and
// registers each with the owner's WaitGroup.
func (c *TcpConn) run() {
	for _, loop := range []func(){c.reader, c.writer, c.handler} {
		startGoroutine(loop, c.owner.waitGroup)
	}
}

// Close tears the connection down exactly once: it marks it closed, closes
// the packet channels and the socket, and then runs the owner's close hook.
// Extra calls are no-ops.
func (c *TcpConn) Close() {
	c.closeOnce.Do(func() {
		atomic.StoreInt32(&c.closedFlag, 1)
		// Closing the packet channels also unblocks goroutines parked on a
		// send (Write, or Parse inside reader). The resulting "send on closed
		// channel" panic is recovered by those callers.
		close(c.sendChan)
		close(c.recvChan)
		close(c.closeChan)
		c.conn.Close()
		if hook := c.onClose; hook != nil {
			hook(c)
		}
	})
}

// Closed reports whether Close has been called on this connection.
func (c *TcpConn) Closed() bool { return atomic.LoadInt32(&c.closedFlag) != 0 }

// RawConn exposes the underlying socket, for example to read peer addresses.
func (c *TcpConn) RawConn() *net.TCPConn { return c.conn }

// startGoroutine runs fn on a new goroutine tracked by wg. The counter is
// incremented before the goroutine starts and decremented when fn returns.
func startGoroutine(fn func(), wg *sync.WaitGroup) {
	wg.Add(1)
	go func(task func(), group *sync.WaitGroup) {
		defer group.Done()
		task()
	}(fn, wg)
}

// reader reads the socket into a RecvBufLenMax-byte buffer and hands each
// chunk to the protocol's Parse, which pushes decoded packets to recvChan.
// It exits on owner shutdown, on connection close, or on a read error, and
// always closes the connection on the way out.
func (c *TcpConn) reader() {
	defer func() {
		// Swallow panics (notably "send on closed channel" when Parse pushes
		// into recvChan after Close), then make sure the connection is down.
		recover()
		c.Close()
	}()

	buf := make([]byte, RecvBufLenMax)
	for {
		select {
		case <-c.owner.exitChan:
			return
		case <-c.closeChan:
			return
		default:
		}

		n, err := c.conn.Read(buf)
		// Per the io.Reader contract, bytes returned together with an error
		// are valid and must be consumed before acting on the error.
		if n > 0 {
			c.owner.proto.Parse(buf[:n], c.recvChan)
		}
		if err != nil {
			return
		}
	}
}

// writer drains sendChan and writes each packet's Marshal output to the
// socket. It exits on owner shutdown, on connection close, or on a write
// error, and always closes the connection on the way out.
func (c *TcpConn) writer() {
	defer func() {
		recover()
		c.Close()
	}()

	for !c.Closed() {
		select {
		case <-c.owner.exitChan:
			return
		case <-c.closeChan:
			return
		case p, ok := <-c.sendChan:
			// A closed sendChan (Close ran) or a nil packet has nothing to
			// marshal. Stop explicitly instead of panicking on p.Marshal().
			if !ok || p == nil {
				return
			}
			if _, err := c.conn.Write(p.Marshal()); err != nil {
				return
			}
		}
	}
}

// handler passes each packet from recvChan to the protocol's Process, in
// arrival order. It exits on owner shutdown or connection close, and always
// closes the connection on the way out.
func (c *TcpConn) handler() {
	defer func() {
		recover()
		c.Close()
	}()

	for !c.Closed() {
		select {
		case <-c.owner.exitChan:
			return
		case <-c.closeChan:
			return
		case packet, ok := <-c.recvChan:
			// recvChan is closed by Close. Do not hand Process the zero
			// (nil) Packet that a closed channel yields.
			if !ok {
				return
			}
			c.owner.proto.Process(c, packet)
		}
	}
}

// Write queues p for the writer goroutine. It is a no-op once the connection
// is closed, and it blocks while the send queue (capacity SendBufCapMax as
// configured by both constructors) is full.
func (c *TcpConn) Write(p Packet) {
	if c.Closed() {
		return
	}
	// Close may close sendChan between the check above and the send below.
	// The resulting panic is swallowed and p is dropped.
	defer func() { _ = recover() }()
	c.sendChan <- p
}
// Copyright (C) 2017 ecofast(胡光耀). All rights reserved.
// Use of this source code is governed by a BSD-style license.

package tcpsock

import (
	"net"
	"sync"

	. "github.com/ecofast/sysutils"
)

// TcpClient dials a single TCP connection to a server and runs it through the
// given Protocol, reusing the same TcpConn machinery as the server side.
type TcpClient struct {
	svrAddr *net.TCPAddr // resolved once in NewTcpClient
	*tcpSock
}

// NewTcpClient prepares a client for the server at svrAddr ("host:port").
// Resolution errors are handed to CheckError (sysutils). No connection is
// opened until Run is called.
func NewTcpClient(svrAddr string, proto Protocol) *TcpClient {
	resolved, err := net.ResolveTCPAddr("tcp", svrAddr)
	CheckError(err)

	shared := &tcpSock{
		sendBufCap: SendBufCapMax,
		recvBufCap: RecvBufCapMax,
		proto:      proto,
		exitChan:   make(chan struct{}),
		waitGroup:  new(sync.WaitGroup),
	}
	return &TcpClient{svrAddr: resolved, tcpSock: shared}
}

// Run dials the server synchronously (errors go to CheckError). It then, on a
// tracked goroutine, wraps the socket in a TcpConn, fires the OnConnect
// callback and starts the connection's goroutines.
func (cli *TcpClient) Run() {
	conn, err := net.DialTCP("tcp", nil, cli.svrAddr)
	CheckError(err)

	cli.waitGroup.Add(1)
	go func() {
		defer cli.waitGroup.Done()
		// A client owns exactly one connection, so it needs no ID: use 0.
		c := newTcpConn(0, cli.tcpSock, conn, cli.sendBufCap, cli.recvBufCap, cli.connClose)
		if cb := cli.onConnConnect; cb != nil {
			cb(c)
		}
		c.run()
	}()
}

// Close signals the connection goroutines to stop and waits for them to exit.
// It closes exitChan, so call it only once.
func (cli *TcpClient) Close() {
	defer cli.waitGroup.Wait()
	close(cli.exitChan)
}

// OnConnect registers fn to run once the connection is established, before
// its goroutines start. Set it before calling Run.
func (cli *TcpClient) OnConnect(fn OnTcpConnCallback) {
	cli.tcpSock.onConnConnect = fn
}

// OnClose registers fn to run after the connection's socket has been closed.
// Set it before calling Run.
func (cli *TcpClient) OnClose(fn OnTcpConnCallback) {
	cli.tcpSock.onConnClose = fn
}

// connClose is the close hook given to the client's TcpConn. It forwards to
// the OnClose callback, if one is set.
func (cli *TcpClient) connClose(conn *TcpConn) {
	if cb := cli.onConnClose; cb != nil {
		cb(conn)
	}
}

  可以看到,tcpsock 的实现中有几处地方使用到了回调函数,或许后面会有更好的封装及处理方法(当然这绝不是说回调函数不好),暂时先这样设计吧。

  而我也编写了个简单的包括服务器与客户端的聊天室程序,用以测试及验证 tcpsock,虽然这个示例挺简单,但可能比较实用,包括但不限于如何使用 tcpsock,及二进制协议的设计与解析(涉及到 TCP 粘包处理)等,代码如下:

package main

import (
	"log"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"tcpsock"
	. "tcpsock/samples/chatroom/protocol"
	"time"

	. "github.com/ecofast/sysutils"
)

var (
	// shutdown receives a value once SIGINT or SIGTERM arrives (see init);
	// main blocks on it before closing the server.
	shutdown = make(chan bool, 1)

	// mutex guards clients, which the connect/close callbacks, broadcast and
	// onMsg access from different goroutines.
	mutex   sync.Mutex
	clients map[uint32]*tcpsock.TcpConn // live connections keyed by TcpConn.ID
)

// init turns SIGINT/SIGTERM into a single value on shutdown.
func init() {
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	go func(ch <-chan os.Signal) {
		<-ch
		shutdown <- true
	}(sigCh)
}

// onConnConnect records a newly accepted connection in clients.
func onConnConnect(conn *tcpsock.TcpConn) {
	id := conn.ID()
	mutex.Lock()
	defer mutex.Unlock()
	clients[id] = conn
}

// onConnClose forgets a connection once it has closed.
func onConnClose(conn *tcpsock.TcpConn) {
	id := conn.ID()
	mutex.Lock()
	defer mutex.Unlock()
	delete(clients, id)
}

// genChatPacket builds the periodic server broadcast. Its sender is the fixed
// player ID 555555555, and its body is the current time as rendered by
// TimeToStr.
func genChatPacket() *ChatPacket {
	text := "current time is " + TimeToStr(time.Now())
	head := PacketHead{
		Signature: ChatSignature,
		PlayerID:  555555555,
		BodyLen:   uint32(len(text)),
	}
	return NewChatPacket(head, []byte(text))
}

func broadcast() {
	mutex.Lock()
	defer mutex.Unlock()
	packet := genChatPacket()
	for _, c := range clients {
		c.Write(packet)
	}
}

// onMsg relays a chat packet received from conn to every connected client,
// including the sender. As in broadcast, the client set is copied under
// mutex and the potentially blocking writes happen outside the lock.
func onMsg(conn *tcpsock.TcpConn, p *ChatPacket) {
	mutex.Lock()
	targets := make([]*tcpsock.TcpConn, 0, len(clients))
	for _, c := range clients {
		targets = append(targets, c)
	}
	mutex.Unlock()

	for _, c := range targets {
		c.Write(p)
	}
}

// main does the following:
//   - starts the chat server on port 9999 with a 2-second accept timeout;
//   - logs the connection count and broadcasts the server time every
//     10 seconds;
//   - shuts down on SIGINT/SIGTERM.
func main() {
	clients = make(map[uint32]*tcpsock.TcpConn)

	proto := new(ChatProtocol)
	proto.OnMessage(onMsg)

	server := tcpsock.NewTcpServer(9999, 2, proto)
	server.OnConnConnect(onConnConnect)
	server.OnConnClose(onConnClose)
	log.Println("=====service start=====")
	go server.Serve()

	tick := time.NewTicker(10 * time.Second)
	reportLoop := func() {
		for range tick.C {
			log.Printf("num of conn: %d\n", server.NumOfConn())
			broadcast()
		}
	}
	go reportLoop()

	<-shutdown
	server.Close()
	log.Println("=====service end=====")
}
package protocol

import (
	"bytes"
	"encoding/binary"
	"tcpsock"
)

const (
	// ChatSignature is the value every chat packet header starts with; Parse
	// uses it to find packet boundaries in the byte stream.
	ChatSignature = 0xFFFFFFFF
	// PacketHeadSize is the encoded size of PacketHead: three uint32 fields.
	PacketHeadSize = 3 * 4
)

// PacketHead is the fixed 12-byte header that precedes every chat body.
// Field order is the wire order: Bytes encodes the struct little-endian with
// binary.Write, and Parse decodes the fields at offsets 0, 4 and 8.
type PacketHead struct {
	Signature uint32 // expected to equal ChatSignature
	PlayerID  uint32 // sender ID (user-entered on the client; fixed for server broadcasts)
	BodyLen   uint32 // number of body bytes following the header
}

// Bytes returns the 12-byte little-endian wire encoding of the header.
func (head *PacketHead) Bytes() []byte {
	out := new(bytes.Buffer)
	// binary.Write cannot fail here: every field is fixed-size and the
	// destination is an in-memory buffer.
	_ = binary.Write(out, binary.LittleEndian, head)
	return out.Bytes()
}

// ChatPacket is one chat message: the header plus its raw body bytes. It
// implements tcpsock.Packet via Marshal.
type ChatPacket struct {
	PacketHead
	Body []byte // BodyLen bytes of message payload
}

// NewChatPacket pairs head with body. body is stored as given (not copied),
// and head.BodyLen is not checked against len(body).
func NewChatPacket(head PacketHead, body []byte) *ChatPacket {
	p := new(ChatPacket)
	p.PacketHead = head
	p.Body = body
	return p
}

// Marshal encodes the packet for the wire: the 12-byte header followed
// immediately by the body.
func (p *ChatPacket) Marshal() []byte {
	head := p.PacketHead.Bytes()
	out := make([]byte, 0, PacketHeadSize+len(p.Body))
	out = append(out, head[:PacketHeadSize]...)
	return append(out, p.Body...)
}

// ChatProtocol implements tcpsock.Protocol for the chat sample. It keeps
// reassembly state for partially received packets between Parse calls.
// NOTE(review): that state lives on the value, but both samples create one
// *ChatProtocol and hand it to tcpsock, which uses it for every connection.
// On the server, concurrent connections therefore share and race on it.
type ChatProtocol struct {
	recvBuf    []byte                                  // bytes received but not yet parsed into packets
	recvBufLen int                                     // length of recvBuf, tracked alongside it
	onMsg      func(c *tcpsock.TcpConn, p *ChatPacket) // set via OnMessage, called by Process
}

// Parse appends b to the reassembly buffer and sends every complete chat
// packet found there to recvChan. This handles TCP fragmentation and
// coalescing ("sticky packets").
//
// Wire format: a 12-byte little-endian header (Signature, PlayerID, BodyLen)
// followed by BodyLen body bytes. The buffer is handled as follows:
//   - Bytes that do not begin a header carrying ChatSignature are skipped one
//     at a time, to resynchronize.
//   - A header announcing a packet of RecvBufLenMax bytes or more discards
//     everything buffered.
//   - An incomplete trailing packet is kept for the next call.
//
// NOTE(review): the buffer is stored on the ChatProtocol value, which the
// samples share across all connections. Per-connection instances would be
// needed for concurrent clients.
func (cp *ChatProtocol) Parse(b []byte, recvChan chan<- tcpsock.Packet) {
	// Any leftover is a partial packet shorter than RecvBufLenMax, and one
	// read delivers at most RecvBufLenMax bytes, so allow room for both. The
	// old limit of RecvBufLenMax silently dropped whole reads whenever a
	// partial packet was pending, which corrupted the stream.
	if len(b)+cp.recvBufLen > 2*tcpsock.RecvBufLenMax {
		return
	}

	cp.recvBuf = append(cp.recvBuf, b...)
	cp.recvBufLen += len(b)

	consumed := 0
	// ">=" so that a packet with an empty body (exactly one header) is parsed.
	for cp.recvBufLen-consumed >= PacketHeadSize {
		raw := cp.recvBuf[consumed:cp.recvBufLen]
		head := PacketHead{
			Signature: binary.LittleEndian.Uint32(raw[0:4]),
			PlayerID:  binary.LittleEndian.Uint32(raw[4:8]),
			BodyLen:   binary.LittleEndian.Uint32(raw[8:12]),
		}
		if head.Signature != ChatSignature {
			consumed++ // not at a packet start: resynchronize byte by byte
			continue
		}

		// Compare BodyLen before any arithmetic. The old uint32 sum
		// PacketHeadSize+BodyLen could wrap for a hostile BodyLen, slip past
		// the size checks and then panic when slicing the body.
		if head.BodyLen >= tcpsock.RecvBufLenMax-PacketHeadSize {
			consumed = cp.recvBufLen // oversized packet: drop everything buffered
			break
		}
		pkgLen := PacketHeadSize + int(head.BodyLen)
		if pkgLen > len(raw) {
			break // incomplete packet: wait for more data
		}

		// Copy the body so the packet neither aliases nor pins recvBuf.
		body := make([]byte, head.BodyLen)
		copy(body, raw[PacketHeadSize:pkgLen])
		recvChan <- NewChatPacket(head, body)
		consumed += pkgLen
	}

	cp.recvBufLen -= consumed
	if cp.recvBufLen > 0 {
		cp.recvBuf = cp.recvBuf[consumed : consumed+cp.recvBufLen]
	} else {
		cp.recvBuf = nil
	}
}

// Process hands a parsed packet to the callback registered with OnMessage.
// The following are ignored instead of panicking:
//   - packets that are not non-nil *ChatPacket values;
//   - packets arriving before a callback is registered.
// A panic here would be recovered by TcpConn's handler, which then closes
// the connection.
func (cp *ChatProtocol) Process(conn *tcpsock.TcpConn, p tcpsock.Packet) {
	packet, ok := p.(*ChatPacket)
	if !ok || packet == nil || cp.onMsg == nil {
		return
	}
	cp.onMsg(conn, packet)
}

// OnMessage registers fn to receive every chat packet delivered to Process.
func (cp *ChatProtocol) OnMessage(fn func(c *tcpsock.TcpConn, p *ChatPacket)) {
	cp.onMsg = fn
}
package main

import (
	"fmt"
	"log"
	"os"
	"os/signal"
	"syscall"
	"tcpsock"
	. "tcpsock/samples/chatroom/protocol"
)

// ServerAddr is the chat server address the client dials. An empty host part
// means the local system.
const ServerAddr = ":9999"

var (
	// shutdown receives a value when SIGINT or SIGTERM arrives (see init).
	shutdown = make(chan bool, 1)

	// tcpConn is the live connection: set by onConnect, cleared by onClose.
	// NOTE(review): those callbacks run on tcpsock goroutines while input
	// reads tcpConn from another goroutine without synchronization.
	tcpConn *tcpsock.TcpConn
	id      uint32 // this user's player ID, read from stdin by genID
)

// init turns SIGINT/SIGTERM into a single value on shutdown.
func init() {
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	go func(ch <-chan os.Signal) {
		<-ch
		shutdown <- true
	}(sigCh)
}

// main reads the user's ID, connects to ServerAddr, relays typed words as
// chat packets, and disconnects on SIGINT/SIGTERM.
func main() {
	genID()

	proto := new(ChatProtocol)
	proto.OnMessage(onMsg)

	client := tcpsock.NewTcpClient(ServerAddr, proto)
	client.OnConnect(onConnect)
	client.OnClose(onClose)
	defer client.Close()

	go client.Run()
	go input()

	<-shutdown
}

// onConnect logs the server address and publishes the connection for input.
func onConnect(c *tcpsock.TcpConn) {
	remote := c.RawConn().RemoteAddr().String()
	log.Println("successfully connect to server", remote)
	tcpConn = c
}

// onClose logs the disconnect and withdraws the connection from input.
func onClose(c *tcpsock.TcpConn) {
	remote := c.RawConn().RemoteAddr().String()
	log.Println("disconnect from server", remote)
	tcpConn = nil
}

// onMsg prints each received chat message as "<player id>: <body>".
func onMsg(c *tcpsock.TcpConn, p *ChatPacket) {
	text := string(p.Body)
	log.Printf("%d: %s\n", p.PlayerID, text)
}

// genID prompts for and reads this user's numeric player ID from stdin.
// NOTE(review): the Scan error is ignored, so invalid input may leave id at 0.
func genID() {
	const prompt = "pls enter your id: "
	fmt.Print(prompt)
	_, _ = fmt.Scan(&id)
	fmt.Println("your id is:", id)
}

// input reads whitespace-separated words from stdin and sends each one as a
// chat packet. It stops at the first read failure, or as soon as there is no
// live connection.
func input() {
	var word string
	for {
		n, err := fmt.Scan(&word)
		if err != nil || n == 0 || tcpConn == nil {
			return
		}
		tcpConn.Write(genPacket(word))
	}
}

// genPacket wraps s in a chat packet sent on behalf of this user's id.
func genPacket(s string) *ChatPacket {
	body := []byte(s)
	head := PacketHead{
		Signature: ChatSignature,
		PlayerID:  id,
		BodyLen:   uint32(len(body)),
	}
	return NewChatPacket(head, body)
}

  可以看到,服务器与客户端完全共用一套协议处理单元,且各自只需对收到的网络封包作处理即可。

tcpsock for Golang