"""
An RPython implementation of sockets based on ctypes.

Hacked up version of arre's from the pypy project.
"""

from ctypes import create_string_buffer, sizeof, cast
from ctypes import POINTER, c_char, c_char_p, pointer, byref, c_void_p

from pypy.rlib.rarithmetic import intmask
from pypy.rlib.objectmodel import instantiate
from pypy.rpython.rctypes.astruct import offsetof

from mtt.rlib import libc as _c
from mtt.rlib.rctypes import copy_buffer
from mtt.rlib.io import buffers

# Define constants from _c (AF_*, SOCK_*, AI_*, NI_*, ...) as module globals.
constants = _c.constants
locals().update(constants)

ntohs = _c.ntohs
ntohl = _c.ntohl
htons = _c.htons
htonl = _c.htonl

# address family -> Address subclass; filled in by Address.__metaclass__
_FAMILIES = {}


class Address(object):
    """The base class for RPython-level objects representing addresses.

    Fields:  addr    - a _c.sockaddr structure
             addrlen - size used within 'addr'
    """

    class __metaclass__(type):
        # Register every subclass that defines a 'family' attribute, so that
        # familyclass() can map an address family to its Address subclass.
        def __new__(cls, name, bases, attrs):
            family = attrs.get('family')
            A = type.__new__(cls, name, bases, attrs)
            if family is not None:
                _FAMILIES[family] = A
            return A

    def __init__(self, addr, addrlen):
        self.addr = addr
        self.addrlen = addrlen

    def toString(self):
        """Return a printable form of the address (currently __repr__())."""
        # XXX booo!
        return self.__repr__()

# ____________________________________________________________


class SocketError(Exception):
    """Base class of the errors raised by this module."""

    def __init__(self, message):
        self.message = message

    def getMsg(self):
        return self.message


class CSocketError(SocketError):
    """A SocketError that also carries the C error code which caused it."""

    def __init__(self, errno, msg):
        self.errno = errno
        msg = msg + (' : [%s] ' % errno) + _c.getErrorStr(errno)
        SocketError.__init__(self, msg)


def raiseCError(msg):
    """Raise a CSocketError built from the current C errno; never returns."""
    raise CSocketError(_c.geterrno(), msg)

# ____________________________________________________________


def makeipaddr(name, result=None):
    # Convert a string specifying a host name or one of a few symbolic
    # names to an IPAddress instance.  This usually calls getaddrinfo()
    # to do the work; the names "" and "<broadcast>" are special.
    # If 'result' is specified it must be a prebuilt INETAddress or
    # INET6Address that is filled; otherwise a new INETXAddress is returned.
    if result is None:
        family = AF_UNSPEC
    else:
        family = result.family

    if len(name) == 0:
        # The empty name resolves (via AI_PASSIVE) to the wildcard address.
        hints = _c.addrinfo(ai_family = family,
                            ai_socktype = SOCK_DGRAM,   # dummy
                            ai_flags = AI_PASSIVE)
        res = _c.addrinfo_ptr()
        error = _c.getaddrinfo(None, "0", byref(hints), byref(res))
        if error:
            # NOTE(review): 'error' is a getaddrinfo() status code, yet
            # CSocketError formats it with _c.getErrorStr(); confirm that
            # helper understands EAI_* codes.
            raise CSocketError(error, 'make ip address')
        try:
            info = res.contents
            if info.ai_next:
                raise SocketError("wildcard resolved to "
                                  "multiple addresses")
            return makeAddress(info.ai_addr, info.ai_addrlen, result)
        finally:
            _c.freeaddrinfo(res)

    # IPv4 also supports the special name "<broadcast>".
    if name == '<broadcast>':
        return makeipv4addr(intmask(INADDR_BROADCAST), result)

    # "dd.dd.dd.dd" format.
    digits = name.split('.')
    if len(digits) == 4:
        try:
            d0 = int(digits[0])
            d1 = int(digits[1])
            d2 = int(digits[2])
            d3 = int(digits[3])
        except ValueError:
            pass
        else:
            if (0 <= d0 <= 255 and
                0 <= d1 <= 255 and
                0 <= d2 <= 255 and
                0 <= d3 <= 255):
                return makeipv4addr(intmask(htonl(
                    (intmask(d0 << 24)) | (d1 << 16) | (d2 << 8) | (d3 << 0))),
                                    result)

    # generic host name to IP conversion
    hints = _c.addrinfo(ai_family = family)
    res = _c.addrinfo_ptr()
    error = _c.getaddrinfo(name, None, byref(hints), byref(res))
    # PLAT EAI_NONAME
    if error:
        raise CSocketError(error, 'PLAT EAI_NONAME')
    try:
        info = res.contents
        return makeAddress(info.ai_addr, info.ai_addrlen, result)
    finally:
        _c.freeaddrinfo(res)


class IPAddress(Address):
    """AF_INET and AF_INET6 addresses"""

    def getHost(self):
        # Create a string object representing an IP address.
        # For IPv4 this is always a string of the form 'dd.dd.dd.dd'
        # (with variable size numbers).
        buf = create_string_buffer(NI_MAXHOST)
        error = _c.getnameinfo(byref(self.addr), self.addrlen,
                               buf, NI_MAXHOST,
                               None, 0, NI_NUMERICHOST)
        if error:
            raise CSocketError(error, 'getnameinfo')
        return buf.value

# ____________________________________________________________


class INETAddress(IPAddress):
    """An AF_INET (IPv4) host/port address."""
    family = AF_INET
    struct = _c.sockaddr_in
    maxlen = sizeof(struct)

    def __init__(self, host, port):
        makeipaddr(host, self)
        a = self.asSockaddrIn()
        a.sin_port = htons(port)

    def asSockaddrIn(self):
        """View 'addr' as a sockaddr_in; raises SocketError on size mismatch."""
        if self.addrlen != INETAddress.maxlen:
            raise SocketError("invalid address")
        return cast(pointer(self.addr), POINTER(_c.sockaddr_in)).contents

    def __repr__(self):
        try:
            return '<INETAddress %s:%d>' % (self.getHost(), self.getPort())
        except SocketError:
            return '<INETAddress ?>'

    def getPort(self):
        """Return the port number in host byte order."""
        a = self.asSockaddrIn()
        return ntohs(a.sin_port)

    def fromInAddr(in_addr):
        """Build an INETAddress (port 0) from a _c.in_addr structure."""
        sin = _c.sockaddr_in(sin_family = AF_INET)   # PLAT sin_len
        sin.sin_addr = in_addr
        paddr = cast(pointer(sin), _c.sockaddr_ptr)
        result = instantiate(INETAddress)
        # result.addr points into 'sin': keep it referenced, exactly like
        # makeipv4addr() does.
        result._dont_gc_me1 = sin
        result.addr = paddr.contents
        result.addrlen = sizeof(_c.sockaddr_in)
        return result
    fromInAddr = staticmethod(fromInAddr)

    def extract_in_addr(self):
        """Return (void pointer to sin_addr, sizeof(in_addr))."""
        p = cast(pointer(self.asSockaddrIn().sin_addr), c_void_p)
        return p, sizeof(_c.in_addr)

# ____________________________________________________________


class UNIXAddress(Address):
    """An AF_UNIX (filesystem or, on Linux, abstract-namespace) address."""
    family = AF_UNIX
    struct = _c.sockaddr_un
    maxlen = sizeof(struct)

    def __init__(self, path):
        sun = _c.sockaddr_un(sun_family = AF_UNIX)
        if _c.linux and path.startswith('\x00'):
            # Linux abstract namespace extension
            if len(path) > sizeof(sun.sun_path):
                raise SocketError("AF_UNIX path too long")
        else:
            # regular NULL-terminated string
            if len(path) >= sizeof(sun.sun_path):
                raise SocketError("AF_UNIX path too long")
            sun.sun_path[len(path)] = 0
        for i in range(len(path)):
            sun.sun_path[i] = ord(path[i])
        self.sun = sun    # <----- keep a ref so the GC doesn't come here :P
        self.addr = cast(pointer(sun), _c.sockaddr_ptr).contents
        self.addrlen = offsetof(_c.sockaddr_un, 'sun_path') + len(path)

    def asSockAddrUn(self):
        """View 'addr' as a sockaddr_un; raises SocketError if it has no path."""
        if self.addrlen <= offsetof(_c.sockaddr_un, 'sun_path'):
            raise SocketError("invalid address")
        return cast(pointer(self.addr), POINTER(_c.sockaddr_un)).contents

    def __repr__(self):
        try:
            return '<UNIXAddress %r>' % (self.getPath(),)
        except SocketError:
            return '<UNIXAddress ?>'

    def getPath(self):
        a = self.asSockAddrUn()
        if _c.linux and a.sun_path[0] == 0:
            # Linux abstract namespace: raw bytes, length given by addrlen
            buf = copy_buffer(cast(pointer(a.sun_path), POINTER(c_char)),
                              self.addrlen - offsetof(_c.sockaddr_un,
                                                      'sun_path'))
            return buf.raw
        else:
            # regular NULL-terminated string
            return cast(pointer(a.sun_path), c_char_p).value

    def eq(self, other):   # __eq__() is not called by RPython :-/
        return (isinstance(other, UNIXAddress) and
                self.getPath() == other.getPath())

# ____________________________________________________________


def familyclass(family):
    """Return the Address subclass registered for 'family' (or Address)."""
    return _FAMILIES.get(family, Address)
af_get = familyclass

# XXX _dont_gc_meX: we need to keep these around?
# XXX we need to feed this back to pypy


def makeAddress(addrptr, addrlen, result=None):
    """Copy 'addrlen' bytes of the sockaddr at 'addrptr' into an Address.

    Fills 'result' when given (raising SocketError if its family differs),
    otherwise creates an instance of the class registered for the family.
    """
    family = addrptr.contents.sa_family
    if result is None:
        result = instantiate(familyclass(family))
    elif result.family != family:
        raise SocketError("address family mismatched")
    buf = result._dont_gc_me0 = copy_buffer(cast(addrptr, POINTER(c_char)),
                                            addrlen)
    result.addr = cast(buf, _c.sockaddr_ptr).contents
    result.addrlen = addrlen
    return result


def makeipv4addr(s_addr, result=None):
    """Build (or fill 'result' as) an INETAddress for 's_addr', port 0.

    's_addr' is stored into sin_addr.s_addr unchanged.
    """
    if result is None:
        result = instantiate(INETAddress)
    elif result.family != AF_INET:
        raise SocketError("address family mismatched")
    sin = _c.sockaddr_in(sin_family = AF_INET)   # PLAT sin_len
    sin.sin_addr.s_addr = s_addr
    result._dont_gc_me1 = sin
    paddr = cast(pointer(sin), _c.sockaddr_ptr)
    result.addr = paddr.contents
    result.addrlen = sizeof(_c.sockaddr_in)
    return result


def makeNullAddress(family):
    """Return (empty address of 'family' with addrlen 0, buffer capacity).

    The zeroed buffer is sized for the family's maxlen so that calls such
    as accept()/getpeername() can write into it.
    """
    klass = familyclass(family)
    buf = create_string_buffer(klass.maxlen)
    result = instantiate(klass)
    result._dont_gc_me2 = buf
    result.addr = cast(buf, _c.sockaddr_ptr).contents
    result.addrlen = 0
    return result, len(buf)

# ____________________________________________________________


class Socket(object):
    """ RPython-level socket object.
    """
    address = None
    bound = False
    listening = False
    fd = _c.INVALID_SOCKET

    def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0):
        """Create a new socket."""
        fd = _c.socket(family, type, proto)
        if _c.invalid_socket(fd):
            raiseCError("in Socket init()")
        self.fd = fd
        self.family = family
        self.type = type
        self.proto = proto

    def valid(self):
        """Return True if the socket holds a usable descriptor.

        An unusable descriptor is reset to _c.INVALID_SOCKET.
        """
        fd = self.fd
        if fd != _c.INVALID_SOCKET:
            if not _c.invalid_socket(fd):
                return True
            self.fd = _c.INVALID_SOCKET
        return False

    def checkValid(self):
        " raises if invalid "
        if not self.valid():
            raise SocketError("socket invalid")

    def fileno(self):
        if not self.valid():
            raise SocketError("socket already closed")
        return self.fd

    def shutdown(self, how):
        """Shut down the reading side of the socket (flag == SHUT_RD), the
        writing side of the socket (flag == SHUT_WR), or both ends
        (flag == SHUT_RDWR)."""
        res = _c.socketshutdown(self.fd, how)
        if res < 0:
            raiseCError('in shutdown')

    def bind(self, address):
        """Bind the socket to a local address."""
        res = _c.socketbind(self.fd, byref(address.addr), address.addrlen)
        if res < 0:
            raiseCError('in bind')

    def close(self):
        """Close the socket.  Calling close() again is a no-op."""
        if self.valid():
            fd = self.fd
            # Forget the descriptor *before* closing it, so neither a second
            # close() nor a close() after a failed one can close a descriptor
            # number that the OS has meanwhile handed out elsewhere.
            self.fd = _c.INVALID_SOCKET
            res = _c.socketclose(fd)
            if res != 0:
                raiseCError('in close')

    def setblocking(self, block):
        """Switch between blocking (True) and non-blocking (False) mode by
        toggling O_NONBLOCK in the descriptor's status flags."""
        flags = _c.fcntl(self.fd, _c.F_GETFL, 0)
        if flags < 0:
            # don't OR bits into -1 and push that back with F_SETFL
            raiseCError('in setblocking')
        if block:
            flags &= ~_c.O_NONBLOCK
        else:
            flags |= _c.O_NONBLOCK
        if _c.fcntl(self.fd, _c.F_SETFL, flags) < 0:
            raiseCError('in setblocking')

    def setTcpNoDelay(self, flag):
        " if set will disable the Nagle algorithm. "
        self.setsockopt(_c.IPPROTO_TCP, _c.TCP_NODELAY, flag)

    def setReceiveBufferSize(self, size):
        " attempts to set the size of the receive buffer (SO_RCVBUF) "
        self.setsockopt(_c.SOL_SOCKET, _c.SO_RCVBUF, size)

    def setSendBufferSize(self, size):
        " attempts to set the size of the send buffer (SO_SNDBUF) "
        self.setsockopt(_c.SOL_SOCKET, _c.SO_SNDBUF, size)

    def getsockopt(self, level, option):
        """Return the integer value of a socket option."""
        flag = _c.c_int()
        flagsize = _c.socklen_t()
        flagsize.value = _c.sizeof(flag)
        res = _c.socketgetsockopt(self.fd, level, option,
                                  byref(flag), byref(flagsize))
        if res < 0:
            raiseCError('in getsockopt')
        return flag.value

    def setsockopt(self, level, option, value):
        """Set an integer-valued socket option."""
        flag = _c.c_int(value)
        res = _c.socketsetsockopt(self.fd, level, option,
                                  byref(flag), _c.sizeof(flag))
        if res < 0:
            raiseCError('in setsockopt')

    def getpeername(self):
        """Return the address of the remote endpoint."""
        address, addrlen = self._addrbuf()
        res = _c.socketgetpeername(self.fd, byref(address.addr),
                                   byref(addrlen))
        if res < 0:
            raiseCError('in getpeername')
        address.addrlen = addrlen.value
        return address

    def _addrbuf(self):
        # (empty Address of our family, socklen_t holding its capacity)
        addr, maxlen = makeNullAddress(self.family)
        return addr, _c.socklen_t(maxlen)

    def recv(self, buf, buffersize, flags=0):
        """Receive up to 'buffersize' bytes into 'buf'; return the count."""
        read_bytes = _c.socketrecv(self.fd, buf, buffersize, flags)
        if read_bytes < 0:
            raiseCError('in recv')
        return read_bytes

    def recvfrom(self, buf, buffersize, flags=0):
        """Like recv(), but return (count, sender address or None)."""
        address, addrlen = self._addrbuf()
        read_bytes = _c.recvfrom(self.fd, buf, buffersize, flags,
                                 byref(address.addr), byref(addrlen))
        if read_bytes < 0:
            raiseCError('in recvfrom')
        result_addrlen = addrlen.value
        if result_addrlen:
            address.addrlen = result_addrlen
        else:
            address = None
        return read_bytes, address

    def send(self, buf, bufsize, flags=0):
        """Send up to 'bufsize' bytes from 'buf'; return the count sent."""
        res = _c.send(self.fd, buf, bufsize, flags)
        if res < 0:
            raiseCError('in send')
        return res

    def sendto(self, buf, bufsize, address, flags=0):
        """Like send(), but to the given 'address'."""
        res = _c.sendto(self.fd, buf, bufsize, flags,
                        byref(address.addr), address.addrlen)
        if res < 0:
            raiseCError('in sendto')
        return res
# ____________________________________________________________


class AcceptSocket(Socket):
    """Server-side socket: bind(), then listen(), then accept()."""

    def bind(self, address):
        """Bind the socket to a local address.

        SO_REUSEADDR is enabled before binding (presumably so a restarted
        server can rebind its port immediately)."""
        assert not self.bound
        self.setsockopt(_c.SOL_SOCKET, _c.SO_REUSEADDR, 1)
        res = _c.socketbind(self.fd, byref(address.addr), address.addrlen)
        if res < 0:
            raiseCError('in AcceptSocket.bind()')
        self.bound = True

    def listen(self, backlog):
        """Enable a server to accept connections.  The backlog argument
        must be at least 1 (smaller values are raised to 1); it specifies
        the number of unaccepted connections that the system will allow
        before refusing new connections."""
        assert self.bound, "socket needs to be bound first"
        if backlog < 1:
            backlog = 1
        res = _c.socketlisten(self.fd, backlog)
        if res < 0:
            raiseCError('in AcceptSocket.listen()')
        self.listening = True

    def accept(self):
        """Wait for an incoming connection.

        Returns a ReadWriteSocket for the new connection, or None when
        accept() failed with EWOULDBLOCK (nothing pending on a non-blocking
        socket).  Any other failure raises CSocketError."""
        assert self.listening, "socket needs to be listening first"
        address, addrlen = self._addrbuf()
        newfd = _c.socketaccept(self.fd, byref(address.addr),
                                byref(addrlen))
        if newfd == -1 and _c.geterrno() == _c.EWOULDBLOCK:
            return None
        if _c.invalid_socket(newfd):
            raiseCError('in AcceptSocket.accept()')
        address.addrlen = addrlen.value
        # NOTE(review): the peer address filled in above is not attached to
        # the returned socket; callers needing it must use getpeername().

        # create a new socket for channel based io
        return makeSocket(newfd, self.family, self.type, self.proto,
                          SocketClass=ReadWriteSocket)


class ReadWriteSocket(Socket):
    " minimal interface for channel based io "

    def _readinto(self, rawbufptr, size, flags=0):
        """recv() up to 'size' bytes into 'rawbufptr'; return the count.

        EWOULDBLOCK is reported as 0.  Other errors are NOT raised: the
        negative recv() result is returned and must be checked by the caller.
        NOTE(review): 0 is also what recv() returns at end-of-stream, so a
        caller cannot tell EOF from "no data yet" -- confirm callers cope.
        """
        readbytes = _c.socketrecv(self.fd, rawbufptr, size, flags)
        if readbytes < 0:
            err = _c.geterrno()
            if err == _c.EWOULDBLOCK:
                readbytes = 0
        return readbytes

    def _writefrom(self, rawbufptr, size, flags=0):
        """send() up to 'size' bytes from 'rawbufptr'; return the count.

        EWOULDBLOCK is reported as 0.  Other errors are NOT raised: the
        negative send() result is returned and must be checked by the caller.
        """
        writebytes = _c.send(self.fd, rawbufptr, size, flags)
        if writebytes < 0:
            err = _c.geterrno()
            if err == _c.EWOULDBLOCK:
                writebytes = 0
        return writebytes


class ConnectSocket(ReadWriteSocket):
    """Client-side socket supporting (possibly non-blocking) connect()."""

    def connect(self, address):
        """Connect the socket to a remote address.

        A connection that is still being established (EINPROGRESS or
        EALREADY, e.g. on a non-blocking socket) is not an error: 0 is
        returned and finishConnect() can be polled afterwards.  Other
        failures raise CSocketError."""
        self.checkValid()
        res = _c.socketconnect(self.fd, byref(address.addr),
                               address.addrlen)
        if res < 0:
            posserr = _c.geterrno()
            if posserr == _c.EALREADY or posserr == _c.EINPROGRESS:
                res = 0
            else:
                raiseCError('in ConnectSocket.connect()')
        self.address = address
        return res

    def finishConnect(self):
        """ will return true if connection established, or false if
        still in progress, and raises if there was any error connecting.
        This seems like the only sane semantics. """
        assert self.address is not None
        res = _c.socketconnect(self.fd, byref(self.address.addr),
                               self.address.addrlen)
        if res >= 0:
            # connect() succeeded outright: the connection is established.
            return True
        posserr = _c.geterrno()
        if posserr == _c.EISCONN:
            return True
        if posserr == _c.EALREADY or posserr == _c.EINPROGRESS:
            return False
        raiseCError('in ConnectSocket.finishConnect()')


def makeSocket(fd, family, type, proto, SocketClass=Socket):
    """Wrap an existing descriptor in a new SocketClass instance without
    running its __init__; raises SocketError if 'fd' is not valid."""
    sock = instantiate(SocketClass)
    sock.fd = fd
    sock.family = family
    sock.type = type
    sock.proto = proto
    sock.checkValid()
    return sock
makeSocket._annspecialcase_ = 'specialize:arg(4)'

# ____________________________________________________________


def gethostname():
    """Return the host name reported by gethostname() (at most 1023 bytes)."""
    buf = create_string_buffer(1024)
    res = _c.gethostname(buf, sizeof(buf)-1)
    if res < 0:
        raise CSocketError(_c.geterrno(), 'gethostname()')
    # force termination in case a long name was truncated without a NUL
    buf[sizeof(buf)-1] = '\x00'
    return buf.value


def gethostbyname(name):
    """Resolve 'name' to an INETAddress (IPv4 only)."""
    # this is explicitly not working with IPv6, because the docs say it
    # should not.  Just use makeipaddr(name) for an IPv6-friendly version...
    result = instantiate(INETAddress)
    makeipaddr(name, result)
    return result