Coroutines in Python

Tim Peters tim@ksr.com
Sun, 22 May 94 23:59:42 -0400


Those whose Python has thread support might find the attached demo fun to play with. The "sync" module it imports is the module I posted earlier that implements synchronization objects. BTW, this runs _really_ slowly -- it's just a proof of concept.

Guido, this is one case where a "thread object" (e.g., thread.self()) would really help -- each thread here wants to remember _which_ coroutine it's implementing, and there's no clear way to do that (it's done here via an ad hoc mixture of local variables, class attributes, and storing info in a dict).

amusedly y'rs - tim

Tim Peters tim@ksr.com not speaking for Kendall Square Research Corp

Module Coroutine:

# Coroutine implementation using Python threads.
#
# Combines ideas from Guido's Generator module, and from the coroutine
# features of Icon and Simula 67.
#
# To run a collection of functions as coroutines, you need to create
# a Coroutine object to control them:
#    co = Coroutine()
# and then 'create' a subsidiary object for each function in the
# collection:
#    cof1 = co.create(f1 [, arg1, arg2, ...]) # [] means optional,
#    cof2 = co.create(f2 [, arg1, arg2, ...]) #... not list
#    cof3 = co.create(f3 [, arg1, arg2, ...])
# etc.  The functions need not be distinct; 'create'ing the same
# function multiple times gives you independent instances of the
# function.
#
# To start the coroutines running, use co.tran on one of the create'd
# functions; e.g., co.tran(cof2).  The routine that first executes
# co.tran is called the "main coroutine".  It's special in several
# respects:  it existed before you created the Coroutine object; if any of
# the create'd coroutines exits (does a return, or suffers an unhandled
# exception), EarlyExit error is raised in the main coroutine; and the
# co.detach() method transfers control directly to the main coroutine
# (you can't use co.tran() for this because the main coroutine doesn't
# have a name ...).
#
# Coroutine objects support these methods:
#
# handle = .create(func [, arg1, arg2, ...])
#    Creates a coroutine for an invocation of func(arg1, arg2, ...),
#    and returns a handle ("name") for the coroutine so created.  The
#    handle can be used as the target in a subsequent .tran().
#
# .tran(target, data=None)
#    Transfer control to the create'd coroutine "target", optionally
#    passing it an arbitrary piece of data. To the coroutine A that does
#    the .tran, .tran acts like an ordinary function call:  another
#    coroutine B can .tran back to it later, and if it does A's .tran
#    returns the 'data' argument passed to B's tran.  E.g.,
#
#    in coroutine coA   in coroutine coC    in coroutine coB
#      x = co.tran(coC)   co.tran(coB)        co.tran(coA,12)
#      print x # 12
#
#    The data-passing feature is taken from Icon, and greatly cuts
#    the need to use global variables for inter-coroutine communication.
#
# .back( data=None )
#    The same as .tran(invoker, data), where 'invoker' is the
#    coroutine that most recently .tran'ed control to the coroutine
#    doing the .back.  This is akin to Icon's "&source".
#
# .detach( data=None )
#    The same as .tran(main, data), where 'main' is the
#    (unnameable!) coroutine that started it all.  'main' has all the
#    rights of any other coroutine:  upon receiving control, it can
#    .tran to an arbitrary coroutine of its choosing, go .back to
#    the .detach'er, or .kill the whole thing.
#
# .kill()
#    Destroy all the coroutines and wake the main coroutine.  None of the
#    create'd coroutines can be resumed after a .kill().  If a create'd
#    coroutine calls .kill(), it keeps running until it returns, so it
#    should return promptly.  An EarlyExit exception does a .kill()
#    automatically; calling .kill() a second time raises TypeError.  It's
#    a good idea to .kill() coroutines you're done with, since the
#    current implementation consumes a thread for each coroutine that
#    may be resumed.

import thread
import sync

class _CoEvent:
    def __init__(self, func):
        self.f = func
        self.e = sync.event()

    def __repr__(self):
        if self.f is None:
            return 'main coroutine'
        else:
            return 'coroutine for func ' + self.f.func_name

    def __hash__(self):
        return id(self)

    def __cmp__(x,y):
        return cmp(id(x), id(y))

    def resume(self):
        self.e.post()

    def wait(self):
        self.e.wait()
        self.e.clear()

# String exception objects.  .tran raises Killed in a create'd coroutine
# that is woken by a .kill(), and raises EarlyExit in the main coroutine
# when a create'd coroutine exits on its own.
Killed, EarlyExit = 'Coroutine.Killed', 'Coroutine.EarlyExit'

class Coroutine:
    # Controller for a collection of coroutines, each create'd coroutine
    # running in its own thread.  Control is handed from one coroutine to
    # another by .tran (and .back/.detach, which call it): the caller posts
    # the target's event and then blocks on its own, so only the coroutine
    # named by self.active is meant to be running at any moment.

    def __init__(self):
        # 'main' stands for the main coroutine: since 'active' starts out as
        # main, whoever makes the first .tran blocks on main's event.
        self.active = self.main = _CoEvent(None)
        # Maps each known coroutine to the one that most recently .tran'ed
        # to it (None until that happens).  Its keys are also the set of
        # legal .tran targets.
        self.invokedby = {self.main: None}
        self.killed = 0
        # The 'data' of the most recent .tran, returned to the resumed side.
        self.value  = None
        # The create'd coroutine whose exit caused the automatic .kill(),
        # or None.
        self.terminated_by = None

    def create(self, func, *args):
        # Register a coroutine for func(*args) and start its thread now.
        # The thread blocks in _start until the coroutine is first resumed.
        # Returns the coroutine's handle (its _CoEvent), for use with .tran.
        me = _CoEvent(func)
        self.invokedby[me] = None
        thread.start_new_thread(self._start, (me,) + args)
        return me

    def _start(self, me, *args):
        # Thread body for the create'd coroutine 'me'.
        me.wait()
        # Also woken by .kill(); in that case the function is never run.
        if not self.killed:
            try:
                try:
                    apply(me.f, args)
                except Killed:
                    # Raised by .tran in this thread after a .kill().
                    pass
            finally:
                # Any other exit while still alive (normal return or an
                # uncaught exception, which then propagates out of this
                # thread function) kills everything.  The main coroutine's
                # pending .tran then raises EarlyExit naming 'me'.
                if not self.killed:
                    self.terminated_by = me
                    self.kill()

    def kill(self):
        # Mark all coroutines dead and post every coroutine's event, waking
        # all blocked threads.  Create'd coroutines waiting in .tran raise
        # Killed, ones never started skip their function, and the main
        # coroutine returns from (or raises EarlyExit out of) its .tran.
        # Raises TypeError if already killed.
        # NOTE(review): this does not stop the caller.  A create'd coroutine
        # that calls kill() keeps running, concurrently with the woken main
        # coroutine, until it returns.
        if self.killed:
            raise TypeError, 'kill() called on dead coroutines'
        self.killed = 1
        for coroutine in self.invokedby.keys():
            coroutine.resume()

    def back(self, data=None):
        # .tran to whoever most recently .tran'ed to the active coroutine.
        # If nobody has, that is None, and .tran raises TypeError.
        return self.tran( self.invokedby[self.active], data )

    def detach(self, data=None):
        # .tran to the main coroutine.
        return self.tran( self.main, data )

    def tran(self, target, data=None):
        # Give control to 'target', handing it 'data', and block until some
        # coroutine transfers control back.  Then return the data that
        # transfer passed (self.value).
        # Raises TypeError if 'target' is unknown or the coroutines were
        # killed.  If a kill happens while blocked: raises Killed in a
        # create'd coroutine; in the main coroutine, raises EarlyExit when a
        # create'd coroutine terminated, and otherwise returns normally.
        if not self.invokedby.has_key(target):
            raise TypeError, '.tran target ' + `target` + \
                             ' is not an active coroutine'
        if self.killed:
            raise TypeError, '.tran target ' + `target` + ' is killed'
        # Everything must be recorded before waking target, since target
        # runs (and may read these) as soon as it is resumed.
        self.value = data
        me = self.active
        self.invokedby[target] = me   # remembered for target's .back()
        self.active = target
        target.resume()

        me.wait()
        if self.killed:
            if self.main is not me:
                raise Killed
            if self.terminated_by is not None:
                raise EarlyExit, `self.terminated_by` + ' terminated early'

        return self.value

# end of module

Example 1:

# Coroutine example:  controlling multiple instances of a single function

from Coroutine import *

# fringe visits a nested list left to right, and detaches for each non-list
# element; once the list is exhausted fringe returns, which makes the main
# coroutine's .tran raise EarlyExit
def fringe( co, list ):
    # Walk 'list' left to right, descending into sublists, and hand each
    # non-list element to the main coroutine with co.detach.  When run as a
    # create'd coroutine, returning after the last element makes the main
    # coroutine's pending .tran raise EarlyExit.
    for item in list:
        if type(item) is not type([]):
            co.detach(item)
        else:
            fringe(co, item)

def printinorder( list ):
    co = Coroutine()
    f = co.create(fringe, co, list)
    try:
        while 1:
            print co.tran(f),
    except EarlyExit:
        pass
    print

# Each call prints the fringe of its argument on one line; the expected
# output is in the trailing comment.  x is reused by the fcmp tests below.
printinorder([1,2,3])  # 1 2 3
printinorder([[[[1,[2]]],3]]) # ditto
x = [0, 1, [2, [3]], [4,5], [[[6]]] ]
printinorder(x) # 0 1 2 3 4 5 6

# fcmp lexicographically compares the fringes of two nested lists
def fcmp( l1, l2 ):
    # Compare the fringes of l1 and l2 element by element.  Returns 0 if
    # they are equal; otherwise -1 or 1, as cmp() gives for the first
    # differing pair.  A fringe that runs out first compares smaller.
    # Each fringe comes from its own Coroutine.  One that ran out was
    # already killed by its EarlyExit, so only the live ones are .kill()ed
    # (a second kill would raise TypeError).
    co1 = Coroutine(); f1 = co1.create(fringe, co1, l1)
    co2 = Coroutine(); f2 = co2.create(fringe, co2, l2)
    while 1:
        try:
            v1 = co1.tran(f1)
            done1 = 0
        except EarlyExit:
            done1 = 1
        try:
            v2 = co2.tran(f2)
            done2 = 0
        except EarlyExit:
            done2 = 1
        if done1 and done2:
            return 0
        if done1:
            co2.kill()
            return -1
        if done2:
            co1.kill()
            return 1
        if v1 != v2:
            co1.kill(); co2.kill()
            return cmp(v1,v2)

# Expected results are in the trailing comments: 0 means equal fringes,
# -1 means the first fringe is smaller (or ends first), 1 the second.
print fcmp(range(7), x)  #  0; fringes are equal
print fcmp(range(6), x)  # -1; 1st list ends early
print fcmp(x, range(6))  #  1; 2nd list ends early
print fcmp(range(8), x)  #  1; 2nd list ends early
print fcmp(x, range(8))  # -1; 1st list ends early
print fcmp([1,[[2],8]],
           [[[1],2],8])  #  0
print fcmp([1,[[3],8]],
           [[[1],2],8])  #  1
print fcmp([1,[[2],8]],
           [[[1],2],9])  # -1

# end of example


Example 2:

# Coroutine example:  general coroutine transfers
#
# The program is a variation of a Simula 67 program due to Dahl & Hoare,
# who in turn credit the original example to Conway.
#
# We have a number of input lines, terminated by a 0 byte.  The problem
# is to squash them together into output lines containing 72 characters
# each.  A semicolon must be added between input lines.  Runs of blanks
# and tabs in input lines must be squashed into single blanks.
# Occurrences of "**" in input lines must be replaced by "^".
#
# Here's a test case:

# Input text for the demo.  The \0 byte ending the last line is the
# end-of-input marker that assembler stops at.
test = """\
   d    =   sqrt(b**2  -  4*a*c)
twoa    =   2*a
   L    =   -b/twoa
   R    =   d/twoa
  A1    =   L + R
  A2    =   L - R\0
"""

# The program should print:

# d = sqrt(b^2 - 4*a*c);twoa = 2*a; L = -b/twoa; R = d/twoa; A1 = L + R;
#A2 = L - R
#done

# getline: delivers the next input line to its invoker
# disassembler: grabs input lines from getline, and delivers them one
#    character at a time to squasher, also inserting a semicolon into
#    the stream between lines
# squasher:  grabs characters from disassembler and passes them on to
#    assembler, first replacing "**" with "^" and squashing runs of
#    whitespace
# assembler: grabs characters from squasher and packs them into lines
#    of 72 characters each, delivering each such line to putline;
#    when it sees a null byte, passes the last line to putline and
#    then kills all the coroutines
# putline: grabs lines from assembler, and just prints them

from Coroutine import *

def getline(text):
    # Hand text, one '\n'-separated line at a time (newline removed), back
    # to whichever coroutine most recently transferred control here.
    lines = string.splitfields(text, '\n')
    for card in lines:
        co.back(card)

def disassembler():
    # Repeatedly fetch a line from getline and feed it to squasher one
    # character at a time, following each line with a ';' separator.
    # Strings are sequences, so iterate the line directly rather than
    # indexing through range(len(...)).
    while 1:
        card = co.tran(cogetline)
        for ch in card:
            co.tran(cosquasher, ch)
        co.tran(cosquasher, ';')

def squasher():
    # Read characters from disassembler and pass them on to assembler,
    # replacing each "**" with "^" and each run of blanks/tabs with a
    # single blank.
    #
    # Finding the end of a "*" or of a whitespace run means reading one
    # character too far.  That character is kept in 'pending' and handled
    # from the top of the loop, so it gets the same "**" and whitespace
    # treatment as every other character.  Forwarding it directly would
    # let a "**" that follows whitespace slip through unconverted.
    pending = None
    while 1:
        if pending is None:
            ch = co.tran(codisassembler)
        else:
            ch = pending
            pending = None
        if ch == '*':
            ch2 = co.tran(codisassembler)
            if ch2 == '*':
                ch = '^'
            else:
                pending = ch2
        elif ch in ' \t':
            while 1:
                ch2 = co.tran(codisassembler)
                if ch2 not in ' \t':
                    break
            ch = ' '
            pending = ch2
        co.tran(coassembler, ch)

def assembler():
    # Pack characters from squasher into 72-character lines and send each
    # full line to putline.  On the '\0' end marker, pad the last partial
    # line with blanks to 72 characters, send it, and kill all coroutines.
    buf = ''
    ch = co.tran(cosquasher)
    while ch != '\0':
        # A line is only sent once a character for the next one arrives.
        if len(buf) == 72:
            co.tran(coputline, buf)
            buf = ''
        buf = buf + ch
        ch = co.tran(cosquasher)
    co.tran(coputline, buf + ' ' * (72 - len(buf)))
    co.kill()

def putline():
    while 1:
        line = co.tran(coassembler)
        print line

# Build the pipeline.  'string' is needed by getline, which looks it up at
# run time, so importing it here (after the defs) is enough.
import string
co = Coroutine()
cogetline = co.create(getline, test)
coputline = co.create(putline)
coassembler = co.create(assembler)
codisassembler = co.create(disassembler)
cosquasher = co.create(squasher)

# Start by handing control to putline, which pulls lines from assembler.
# This .tran returns once assembler calls co.kill() after the last line.
co.tran(coputline)
print 'done'

# end of example

>>> END OF MSG