fixing routing.py bug, if time is 24:00
[pywienerlinien] / gotovienna / routing.py
index 98d0573..d82fb3f 100644 (file)
@@ -1,14 +1,15 @@
 #!/usr/bin/env python
 # -*- coding: UTF-8 -*-
 
-from BeautifulSoup import BeautifulSoup, NavigableString
-from urllib2 import urlopen
+from gotovienna.BeautifulSoup import BeautifulSoup, NavigableString
+#from urllib2 import urlopen
+from UrlOpener import urlopen
 from urllib import urlencode
-from datetime import datetime, time
+from datetime import datetime, time, timedelta
 from textwrap import wrap
-import argparse
 import sys
 import os.path
+import re
 
 from gotovienna import defaults
 
@@ -36,7 +37,7 @@ def extract_city(station):
         return station.split(',')[-1].strip()
     else:
         return 'Wien'
-        
+
 def extract_station(station):
     """ Remove city from string
     
@@ -47,7 +48,7 @@ def extract_station(station):
         return station[:station.rindex(',')].strip()
     else:
         return station
-    
+
 def split_station(station):
     """ >>> split_station('Karlsplatz, Wien')
     ('Karlsplatz', 'Wien')
@@ -59,6 +60,37 @@ def split_station(station):
     else:
         return (station, 'Wien')
 
+def guess_location_type(location):
+    """Guess type (stop, address, poi) of a location
+
+    >>> guess_location_type('pilgramgasse')
+    'stop'
+
+    >>> guess_location_type('karlsplatz 14')
+    'address'
+
+    >>> guess_location_type('reumannplatz 12/34')
+    'address'
+    """
+    parts = location.split()
+    first_part = parts[0]
+    last_part = parts[-1]
+
+    # Assume all single-word locations are stops
+    if len(parts) == 1:
+        return 'stop'
+
+    # If the last part is numeric, assume address
+    if last_part.isdigit() and len(parts) > 1:
+        return 'address'
+
+    # Addresses with door number (e.g. "12/34")
+    if all(x.isdigit() or x == '/' for x in last_part):
+        return 'address'
+
+    # Sane default - assume it's a stop/station name
+    return 'stop'
+
 def search(origin_tuple, destination_tuple, dtime=None):
     """ build route request
     returns html result (as urllib response)
@@ -68,13 +100,21 @@ def search(origin_tuple, destination_tuple, dtime=None):
 
     origin, origin_type = origin_tuple
     origin, origin_city = split_station(origin)
-    
+
     destination, destination_type = destination_tuple
     destination, destination_city = split_station(destination)
 
 
-    if not origin_type in POSITION_TYPES or\
-        not destination_type in POSITION_TYPES:
+    if origin_type is None:
+        origin_type = guess_location_type(origin)
+        print 'Guessed origin type:', origin_type
+
+    if destination_type is None:
+        destination_type = guess_location_type(destination)
+        print 'Guessed destination type:', destination_type
+
+    if (origin_type not in POSITION_TYPES or
+            destination_type not in POSITION_TYPES):
         raise ParserError('Invalid position type')
 
     post = defaults.search_post
@@ -88,13 +128,7 @@ def search(origin_tuple, destination_tuple, dtime=None):
     post['place_destination'] = destination_city
     params = urlencode(post)
     url = '%s?%s' % (defaults.action, params)
-
-    try:
-        f = open(DEBUGLOG, 'a')
-        f.write(url + '\n')
-        f.close()
-    except:
-        print 'Unable to write to DEBUGLOG: %s' % DEBUGLOG
+    #print url
 
     return urlopen(url)
 
@@ -115,28 +149,34 @@ class sParser:
 
         return PageType.UNKNOWN
 
+    state = property(check_page)
+
     def get_correction(self):
         names_origin = self.soup.find('select', {'id': 'nameList_origin'})
         names_destination = self.soup.find('select', {'id': 'nameList_destination'})
         places_origin = self.soup.find('select', {'id': 'placeList_origin'})
         places_destination = self.soup.find('select', {'id': 'placeList_destination'})
-        
 
-        if names_origin or names_destination or places_origin or places_destination:
+
+        if any([names_origin, names_destination, places_origin, places_destination]):
             dict = {}
-            
+
             if names_origin:
-                dict['origin'] = map(lambda x: x.text, names_origin.findAll('option'))
+                dict['origin'] = map(lambda x: x.text,
+                                     names_origin.findAll('option'))
             if names_destination:
-                dict['destination'] = map(lambda x: x.text, names_destination.findAll('option'))
-                
+                dict['destination'] = map(lambda x: x.text,
+                                          names_destination.findAll('option'))
+
             if places_origin:
-                dict['place_origin'] = map(lambda x: x.text, names_origin.findAll('option'))
+                dict['place_origin'] = map(lambda x: x.text,
+                                           names_origin.findAll('option'))
             if names_destination:
-                dict['place_destination'] = map(lambda x: x.text, names_destination.findAll('option'))
-    
+                dict['place_destination'] = map(lambda x: x.text,
+                                                names_destination.findAll('option'))
+
             return dict
-        
+
         else:
             raise ParserError('Unable to parse html')
 
@@ -185,24 +225,57 @@ class rParser:
             return None
 
     @classmethod
-    def get_time(cls, x):
+    def get_datetime(cls, x):
         y = rParser.get_tdtext(x, 'col_time')
         if y:
             if (y.find("-") > 0):
-                return map(lambda z: time(*map(int, z.split(':'))), y.split('-'))
+                # overview mode
+                times = map(lambda z: time(*map(int, z.split(':'))), y.split('-'))
+                d = rParser.get_date(x)
+                from_dtime = datetime.combine(d, times[0])
+                if times[0] > times[1]:
+                    # dateline crossing
+                    to_dtime = datetime.combine(d + timedelta(1), times[1])
+                else:
+                    to_dtime = datetime.combine(d, times[1])
+
+                return [from_dtime, to_dtime]
+
             else:
-                # FIXME Error if date in line (dateLineCross)
-                return map(lambda z: time(*map(int, z.split(':'))), wrap(y, 5))
-        else:
-            return []
+                dtregex = {'date' : '\d\d\.\d\d',
+                           'time': '\d\d:\d\d'}
+
+                regex = "\s*(?P<date1>{date})?\s*(?P<time1>{time})\s*(?P<date2>{date})?\s*(?P<time2>{time})\s*".format(**dtregex)
+                ma = re.match(regex, y)
+
+                if not ma:
+                    return []
+
+                gr = ma.groupdict()
+
+                def extract_datetime(gr, n):
+                    if 'date%d' % n in gr and gr['date%d' % n]:
+                        if gr['time%d' % n] == '24:00':
+                            gr['time%d' % n] = '0:00'
+                        from_dtime = datetime.strptime(str(datetime.today().year) + gr['date%d' % n] + gr['time%d' % n], '%Y%d.%m.%H:%M')
+                    else:
+                        d = datetime.today().date()
+                        # Strange times possible at wienerlinien
+                        if gr['time%d' % n] == '24:00':
+                            gr['time%d' % n] = '0:00'
+                            d += timedelta(days=1)
+                        t = datetime.strptime(gr['time%d' % n], '%H:%M').time()
+                        
+                        return datetime.combine(d, t)
+
+                # detail mode
+                from_dtime = extract_datetime(gr, 1)
+                to_dtime = extract_datetime(gr, 2)
+
+                return [from_dtime, to_dtime]
 
-    @classmethod
-    def get_duration(cls, x):
-        y = rParser.get_tdtext(x, 'col_duration')
-        if y:
-            return time(*map(int, y.split(":")))
         else:
-            return None
+            return []
 
     def __iter__(self):
         for detail in self.details():
@@ -212,7 +285,7 @@ class rParser:
         tours = self.soup.findAll('div', {'class': 'data_table tourdetail'})
 
         trips = map(lambda x: map(lambda y: {
-                        'time': rParser.get_time(y),
+                        'timespan': rParser.get_datetime(y),
                         'station': map(lambda z: z[2:].strip(),
                                        filter(lambda x: type(x) == NavigableString, y.find('td', {'class': 'col_station'}).contents)), # filter non NaviStrings
                         'info': map(lambda x: x.strip(),
@@ -247,9 +320,7 @@ class rParser:
             rows = table.findAll('tr')[1:] # cut off headline
 
             overview = map(lambda x: {
-                               'date': rParser.get_date(x),
-                               'time': rParser.get_time(x),
-                               'duration': rParser.get_duration(x), # grab duration
+                               'timespan': rParser.get_datetime(x),
                                'change': rParser.get_change(x),
                                'price': rParser.get_price(x),
                            },