From 643842bb984b4a97bc2da0c5e11acc00a8f2c8ba Mon Sep 17 00:00:00 2001 From: Steve Canny Date: Fri, 12 Jul 2013 22:10:16 -0700 Subject: [PATCH] fix case insensitive _ContentTypesItem lookup ISO/IEC 29500-2, sections 10.1.2.3 and 10.1.2.4 specify that lookup of content type from partname (override) or extension (default) should be case insensitive. This patch adds that behavior. Incidentals: * create separate test class fixture for __getitem__ method of _ContentTypesItem. * change __defaults and __overrides instance attributes to single leading underscore * initialize _defaults and _overrides to empty dict instead of None * remove guard check for _defaults or _overrides is None on __getitem__ * whitespace adjustments --- pptx/packaging.py | 59 +++++++++++++++++++----------------------- test/test_packaging.py | 56 +++++++++++++++++++++++++++------------ 2 files changed, 65 insertions(+), 50 deletions(-) diff --git a/pptx/packaging.py b/pptx/packaging.py index f78276d32..81c12d32b 100644 --- a/pptx/packaging.py +++ b/pptx/packaging.py @@ -551,27 +551,24 @@ class _ContentTypesItem(object): """ def __init__(self): super(_ContentTypesItem, self).__init__() - self.__defaults = None - self.__overrides = None + self._defaults = {} + self._overrides = {} def __getitem__(self, partname): """ Return the content type for the part with *partname*. 
- """ - # raise exception if called before load() - if self.__defaults is None or self.__overrides is None: - tmpl = "lookup _ContentTypesItem['%s'] attempted before load" - raise ValueError(tmpl % partname) - # first look for an explicit content type - if partname in self.__overrides: - return self.__overrides[partname] - # if not, look for a default based on the extension + # look for explicit partname match in overrides (case-insensitive) + for override_partname, content_type in self._overrides.items(): + if override_partname.lower() == partname.lower(): + return content_type + # look for case-insensitive match on extension in default element ext = os.path.splitext(partname)[1] # get extension of partname # with leading dot trimmed off ext = ext[1:] if ext.startswith('.') else ext - if ext in self.__defaults: - return self.__defaults[ext] + for extension, content_type in self._defaults.items(): + if extension.lower() == ext.lower(): + return content_type # if neither of those work, raise an exception tmpl = "no content type for part '%s' in [Content_Types].xml" raise LookupError(tmpl % partname) @@ -579,11 +576,8 @@ def __getitem__(self, partname): def __len__(self): """ Return sum count of Default and Override elements. - """ - count = len(self.__defaults) if self.__defaults is not None else 0 - count += len(self.__overrides) if self.__overrides is not None else 0 - return count + return len(self._defaults) + len(self._overrides) def compose(self, parts): """ @@ -592,9 +586,9 @@ def compose(self, parts): # extensions in this dict include leading '.' 
def_cts = pptx.spec.default_content_types # initialize working dictionaries for defaults and overrides - self.__defaults = dict((ext[1:], def_cts[ext]) - for ext in ('.rels', '.xml')) - self.__overrides = {} + self._defaults = dict((ext[1:], def_cts[ext]) + for ext in ('.rels', '.xml')) + self._overrides = {} # compose appropriate element for each part for part in parts: ext = os.path.splitext(part.partname)[1] @@ -602,9 +596,9 @@ def compose(self, parts): # fancier way to do this, otherwise I don't know what 'xml' # Default entry is for. if ext == '.xml': - self.__overrides[part.partname] = part.content_type + self._overrides[part.partname] = part.content_type elif ext in def_cts: - self.__defaults[ext[1:]] = def_cts[ext] + self._defaults[ext[1:]] = def_cts[ext] else: tmpl = "extension '%s' not found in default_content_types" raise LookupError(tmpl % (ext)) @@ -614,16 +608,16 @@ def compose(self, parts): def element(self): nsmap = {None: pptx.spec.nsmap['ct']} element = etree.Element(qtag('ct:Types'), nsmap=nsmap) - if self.__defaults: - for ext in sorted(self.__defaults.keys()): + if self._defaults: + for ext in sorted(self._defaults.keys()): subelm = etree.SubElement(element, qtag('ct:Default')) subelm.set('Extension', ext) - subelm.set('ContentType', self.__defaults[ext]) - if self.__overrides: - for partname in sorted(self.__overrides.keys()): + subelm.set('ContentType', self._defaults[ext]) + if self._overrides: + for partname in sorted(self._overrides.keys()): subelm = etree.SubElement(element, qtag('ct:Override')) subelm.set('PartName', partname) - subelm.set('ContentType', self.__overrides[partname]) + subelm.set('ContentType', self._overrides[partname]) return element def load(self, fs): @@ -631,15 +625,14 @@ def load(self, fs): Retrieve [Content_Types].xml from specified file system and load it. Returns a reference to this _ContentTypesItem instance to allow generative call, e.g. ``cti = _ContentTypesItem().load(fs)``. 
- """ element = fs.getelement('/[Content_Types].xml') defaults = element.findall(qtag('ct:Default')) overrides = element.findall(qtag('ct:Override')) - self.__defaults = dict((d.get('Extension'), d.get('ContentType')) - for d in defaults) - self.__overrides = dict((o.get('PartName'), o.get('ContentType')) - for o in overrides) + self._defaults = dict((d.get('Extension'), d.get('ContentType')) + for d in defaults) + self._overrides = dict((o.get('PartName'), o.get('ContentType')) + for o in overrides) return self diff --git a/test/test_packaging.py b/test/test_packaging.py index 09db41a8f..c8778d01f 100644 --- a/test/test_packaging.py +++ b/test/test_packaging.py @@ -12,7 +12,7 @@ import os from collections import namedtuple -from hamcrest import assert_that, is_ +from hamcrest import assert_that, equal_to, is_ from lxml import etree from mock import Mock from StringIO import StringIO @@ -117,7 +117,44 @@ def test_getelement_raises_on_binary(self): fs.getelement('/docProps/thumbnail.jpeg') -class Test_ContentTypesItem(TestCase): +class Test_ContentTypesItem__getitem__(TestCase): + """Test dictionary-style access of content type using partname as key""" + def setUp(self): + self.cti = _ContentTypesItem() + + def test_it_finds_default_case_insensitive(self): + """_ContentTypesItem[partname] finds default case insensitive""" + # setup ------------------------ + partname = '/ppt/media/image1.JPG' + content_type = 'image/jpeg' + self.cti._defaults = {'jpg': content_type} + # exercise --------------------- + val = self.cti[partname] + # verify ----------------------- + assert_that(val, is_(equal_to(content_type))) + + def test_it_finds_override_case_insensitive(self): + """_ContentTypesItem[partname] finds override case insensitive""" + # setup ------------------------ + partname = '/foo/bar.xml' + case_mangled_partname = '/FoO/bAr.XML' + content_type = 'application/vnd.content_type' + self.cti._overrides = { + partname: content_type + } + # exercise --------------------- 
+ val = self.cti[case_mangled_partname] + # verify ----------------------- + assert_that(val, is_(equal_to(content_type))) + + def test_getitem_raises_on_bad_partname(self): + """_ContentTypesItem[partname] raises on bad partname""" + # verify ----------------------- + with self.assertRaises(LookupError): + self.cti['!blat/rhumba.1x&'] + + +class Test_ContentTypesItem_compose(TestCase): """Test _ContentTypesItem""" def setUp(self): self.cti = _ContentTypesItem() @@ -161,21 +198,6 @@ def test_element_correct_length(self): # verify ---------------------- self.assertLength(self.cti.element, 24) - def test_getitem_raises_before_load(self): - """_ContentTypesItem[partname] raises before load""" - # verify ---------------------- - with self.assertRaises(ValueError): - self.cti['/ppt/presentation.xml'] - - def test_getitem_raises_on_bad_partname(self): - """_ContentTypesItem[partname] raises on bad partname""" - # setup ------------------------ - fs = FileSystem(zip_pkg_path) - self.cti.load(fs) - # verify ---------------------- - with self.assertRaises(LookupError): - self.cti['!blat/rhumba.1x&'] - def test_load_spotcheck(self): """_ContentTypesItem can load itself from a filesystem instance""" # setup ------------------------