source file: /opt/devel/celery/celery/tests/test_task.py
file stats: 119 lines, 118 executed: 99.2% covered
   1. import unittest
   2. import uuid
   3. import logging
   4. from StringIO import StringIO
   5. 
   6. from celery import task
   7. from celery import registry
   8. from celery.log import setup_logger
   9. from celery import messaging
  10. 
  11. 
  12. def get_string_io_logger():
  13.     sio = StringIO()
  14.     logger = setup_logger(loglevel=logging.INFO, logfile=sio)
  15.     return logger, sio
  16. 
  17. 
  18. # Task run functions can't be closures/lambdas, as they're pickled.
  19. def return_True(self, **kwargs):
  20.     return True
  21. 
  22. 
  23. def raise_exception(self, **kwargs):
  24.     raise Exception("%s error" % self.__class__)
  25. 
  26. 
  27. class IncrementCounterTask(task.Task):
  28.     name = "c.unittest.increment_counter_task"
  29.     count = 0
  30. 
  31.     def run(self, increment_by, **kwargs):
  32.         increment_by = increment_by or 1
  33.         self.__class__.count += increment_by
  34. 
  35. 
  36. class TestCeleryTasks(unittest.TestCase):
  37. 
  38.     def createTaskCls(self, cls_name, task_name=None):
  39.         attrs = {}
  40.         if task_name:
  41.             attrs["name"] = task_name
  42.         cls = type(cls_name, (task.Task, ), attrs)
  43.         cls.run = return_True
  44.         return cls
  45. 
  46.     def assertNextTaskDataEquals(self, consumer, task_id, task_name,
  47.             **kwargs):
  48.         next_task = consumer.fetch()
  49.         task_data = consumer.decoder(next_task.body)
  50.         self.assertEquals(task_data["celeryID"], task_id)
  51.         self.assertEquals(task_data["celeryTASK"], task_name)
  52.         for arg_name, arg_value in kwargs.items():
  53.             self.assertEquals(task_data.get(arg_name), arg_value)
  54. 
  55.     def test_raising_task(self):
  56.         rtask = self.createTaskCls("RaisingTask", "c.unittest.t.rtask")
  57.         rtask.run = raise_exception
  58.         sio = StringIO()
  59. 
  60.         taskinstance = rtask()
  61.         taskinstance(loglevel=logging.INFO, logfile=sio)
  62.         self.assertTrue(sio.getvalue().find("Task got exception") != -1)
  63. 
  64.     def test_incomplete_task_cls(self):
  65.         class IncompleteTask(task.Task):
  66.             name = "c.unittest.t.itask"
  67. 
  68.         self.assertRaises(NotImplementedError, IncompleteTask().run)
  69. 
  70.     def test_regular_task(self):
  71.         T1 = self.createTaskCls("T1", "c.unittest.t.t1")
  72.         self.assertTrue(isinstance(T1(), T1))
  73.         self.assertTrue(T1().run())
  74.         self.assertTrue(callable(T1()),
  75.                 "Task class is callable()")
  76.         self.assertTrue(T1()(),
  77.                 "Task class runs run() when called")
  78. 
  79.         # task without name raises NotImplementedError
  80.         T2 = self.createTaskCls("T2")
  81.         self.assertRaises(NotImplementedError, T2)
  82. 
  83.         registry.tasks.register(T1)
  84.         t1 = T1()
  85.         consumer = t1.get_consumer()
  86.         consumer.discard_all()
  87.         self.assertTrue(consumer.fetch() is None)
  88. 
  89.         # Without arguments.
  90.         tid = t1.delay()
  91.         self.assertNextTaskDataEquals(consumer, tid, t1.name)
  92. 
  93.         # With arguments.
  94.         tid2 = task.delay_task(t1.name, name="George Constanza")
  95.         self.assertNextTaskDataEquals(consumer, tid2, t1.name,
  96.                 name="George Constanza")
  97. 
  98.         self.assertRaises(registry.tasks.NotRegistered, task.delay_task,
  99.                 "some.task.that.should.never.exist.X.X.X.X.X")
 100. 
 101.         # Discarding all tasks.
 102.         task.discard_all()
 103.         tid3 = task.delay_task(t1.name)
 104.         self.assertEquals(task.discard_all(), 1)
 105.         self.assertTrue(consumer.fetch() is None)
 106. 
 107.         self.assertFalse(task.is_done(tid))
 108.         task.mark_as_done(tid, result=None)
 109.         self.assertTrue(task.is_done(tid))
 110. 
 111. 
 112.         publisher = t1.get_publisher()
 113.         self.assertTrue(isinstance(publisher, messaging.TaskPublisher))
 114. 
 115.     def test_taskmeta_cache(self):
 116.         # TODO Needs to test task meta without TASK_META_USE_DB.
 117.         tid = str(uuid.uuid4())
 118.         ckey = task.gen_task_done_cache_key(tid)
 119.         self.assertTrue(ckey.rfind(tid) != -1)
 120. 
 121. 
 122. class TestTaskSet(unittest.TestCase):
 123. 
 124.     def test_counter_taskset(self):
 125.         ts = task.TaskSet(IncrementCounterTask, [
 126.             {},
 127.             {"increment_by": 2},
 128.             {"increment_by": 3},
 129.             {"increment_by": 4},
 130.             {"increment_by": 5},
 131.             {"increment_by": 6},
 132.             {"increment_by": 7},
 133.             {"increment_by": 8},
 134.             {"increment_by": 9},
 135.         ])
 136.         self.assertEquals(ts.task_name, IncrementCounterTask.name)
 137.         self.assertEquals(ts.total, 9)
 138. 
 139.         taskset_id, subtask_ids = ts.run()
 140. 
 141.         consumer = IncrementCounterTask().get_consumer()
 142.         for subtask_id in subtask_ids:
 143.             m = consumer.decoder(consumer.fetch().body)
 144.             self.assertEquals(m.get("celeryTASKSET"), taskset_id)
 145.             self.assertEquals(m.get("celeryTASK"), IncrementCounterTask.name)
 146.             self.assertEquals(m.get("celeryID"), subtask_id)
 147.             IncrementCounterTask().run(increment_by=m.get("increment_by"))
 148.         self.assertEquals(IncrementCounterTask.count, sum(xrange(1, 10)))