Skip to content

Commit

Permalink
fix case insensitive _ContentTypesItem lookup
Browse files Browse the repository at this point in the history
ISO/IEC 29500-2, sections 10.1.2.3 and 10.1.2.4, specify that lookup of a
content type from a partname (override) or extension (default) should be
case-insensitive. This patch adds that behavior.

Incidentals:
* create separate test class fixture for __getitem__ method of
  _ContentTypesItem.
* rename the __defaults and __overrides instance attributes to use a
  single leading underscore
* initialize _defaults and _overrides to empty dict instead of None
* remove the guard in __getitem__ that checked whether _defaults or _overrides is None
* whitespace adjustments
  • Loading branch information
Steve Canny committed Jul 13, 2013
1 parent 3168163 commit 643842b
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 50 deletions.
59 changes: 26 additions & 33 deletions pptx/packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,39 +551,33 @@ class _ContentTypesItem(object):
"""
def __init__(self):
    """
    Start with empty Default and Override mappings so that item lookup and
    ``len()`` can be used before ``load()`` or ``compose()`` is called.
    """
    super(_ContentTypesItem, self).__init__()
    # extension (no leading dot) -> content type
    self._defaults = {}
    # partname -> content type
    self._overrides = {}

def __getitem__(self, partname):
    """
    Return the content type for the part with *partname*.

    An Override entry whose partname matches *partname* takes precedence;
    failing that, the Default entry matching the extension of *partname*
    is used. Both comparisons are case-insensitive. Raises ``LookupError``
    when neither an override nor a default matches.
    """
    # lower-case the lookup key once rather than on every comparison
    partname_lc = partname.lower()
    # look for explicit partname match in overrides (case-insensitive)
    for override_partname, content_type in self._overrides.items():
        if override_partname.lower() == partname_lc:
            return content_type
    # look for case-insensitive match on extension in default element;
    # splitext() keeps the leading dot, which the keys of _defaults do
    # not carry, so trim it off before comparing
    ext = os.path.splitext(partname_lc)[1]
    ext = ext[1:] if ext.startswith('.') else ext
    for extension, content_type in self._defaults.items():
        if extension.lower() == ext:
            return content_type
    # if neither of those work, raise an exception
    tmpl = "no content type for part '%s' in [Content_Types].xml"
    raise LookupError(tmpl % partname)

def __len__(self):
    """
    Return sum count of Default and Override elements.
    """
    # both mappings are always dicts (empty until populated), so no None
    # check is needed before taking their lengths
    return len(self._defaults) + len(self._overrides)

def compose(self, parts):
"""
Expand All @@ -592,19 +586,19 @@ def compose(self, parts):
# extensions in this dict include leading '.'
def_cts = pptx.spec.default_content_types
# initialize working dictionaries for defaults and overrides
self.__defaults = dict((ext[1:], def_cts[ext])
for ext in ('.rels', '.xml'))
self.__overrides = {}
self._defaults = dict((ext[1:], def_cts[ext])
for ext in ('.rels', '.xml'))
self._overrides = {}
# compose appropriate element for each part
for part in parts:
ext = os.path.splitext(part.partname)[1]
# if extension is '.xml', assume an override. There might be a
# fancier way to do this, otherwise I don't know what 'xml'
# Default entry is for.
if ext == '.xml':
self.__overrides[part.partname] = part.content_type
self._overrides[part.partname] = part.content_type
elif ext in def_cts:
self.__defaults[ext[1:]] = def_cts[ext]
self._defaults[ext[1:]] = def_cts[ext]
else:
tmpl = "extension '%s' not found in default_content_types"
raise LookupError(tmpl % (ext))
Expand All @@ -614,32 +608,31 @@ def compose(self, parts):
def element(self):
    """
    Return a newly built ``ct:Types`` element containing one ``ct:Default``
    child per entry in the defaults mapping followed by one ``ct:Override``
    child per entry in the overrides mapping, each group in sorted key
    order.
    """
    nsmap = {None: pptx.spec.nsmap['ct']}
    element = etree.Element(qtag('ct:Types'), nsmap=nsmap)
    # iterating an empty mapping adds no children, so no emptiness guard
    # is needed around either loop
    for ext in sorted(self._defaults):
        subelm = etree.SubElement(element, qtag('ct:Default'))
        subelm.set('Extension', ext)
        subelm.set('ContentType', self._defaults[ext])
    for partname in sorted(self._overrides):
        subelm = etree.SubElement(element, qtag('ct:Override'))
        subelm.set('PartName', partname)
        subelm.set('ContentType', self._overrides[partname])
    return element

def load(self, fs):
    """
    Retrieve [Content_Types].xml from specified file system and load it.
    Returns a reference to this _ContentTypesItem instance to allow
    generative call, e.g. ``cti = _ContentTypesItem().load(fs)``.

    Any previously held defaults and overrides are replaced by the
    entries read from the file.
    """
    element = fs.getelement('/[Content_Types].xml')
    # Default elements map an extension to a content type
    self._defaults = dict(
        (d.get('Extension'), d.get('ContentType'))
        for d in element.findall(qtag('ct:Default'))
    )
    # Override elements map a specific partname to a content type
    self._overrides = dict(
        (o.get('PartName'), o.get('ContentType'))
        for o in element.findall(qtag('ct:Override'))
    )
    return self


Expand Down
56 changes: 39 additions & 17 deletions test/test_packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import os

from collections import namedtuple
from hamcrest import assert_that, is_
from hamcrest import assert_that, equal_to, is_
from lxml import etree
from mock import Mock
from StringIO import StringIO
Expand Down Expand Up @@ -117,7 +117,44 @@ def test_getelement_raises_on_binary(self):
fs.getelement('/docProps/thumbnail.jpeg')


class Test_ContentTypesItem(TestCase):
class Test_ContentTypesItem__getitem__(TestCase):
    """Test dictionary-style access of content type using partname as key"""
    def setUp(self):
        # each test starts from a freshly constructed, unloaded instance
        self.cti = _ContentTypesItem()

    def test_it_finds_default_case_insensitive(self):
        """_ContentTypesItem[partname] finds default case insensitive"""
        # the Default key is lower-case while the partname extension is
        # upper-case; the lookup must still succeed
        expected = 'image/jpeg'
        self.cti._defaults = {'jpg': expected}
        actual = self.cti['/ppt/media/image1.JPG']
        assert_that(actual, is_(equal_to(expected)))

    def test_it_finds_override_case_insensitive(self):
        """_ContentTypesItem[partname] finds override case insensitive"""
        # the Override is stored under one casing and looked up under
        # another
        expected = 'application/vnd.content_type'
        self.cti._overrides = {'/foo/bar.xml': expected}
        actual = self.cti['/FoO/bAr.XML']
        assert_that(actual, is_(equal_to(expected)))

    def test_getitem_raises_on_bad_partname(self):
        """_ContentTypesItem[partname] raises on bad partname"""
        # nothing has been added to the instance, so no entry can match
        with self.assertRaises(LookupError):
            self.cti['!blat/rhumba.1x&']


class Test_ContentTypesItem_compose(TestCase):
"""Test _ContentTypesItem"""
def setUp(self):
self.cti = _ContentTypesItem()
Expand Down Expand Up @@ -161,21 +198,6 @@ def test_element_correct_length(self):
# verify ----------------------
self.assertLength(self.cti.element, 24)

def test_getitem_raises_before_load(self):
"""_ContentTypesItem[partname] raises before load"""
# Expects ValueError from an item lookup made without any prior load()
# call in this test; presumably self.cti is the instance built in setUp.
# NOTE(review): the commit message says the before-load guard in
# __getitem__ was removed, so this test appears to belong to the removed
# side of the diff -- confirm before relying on it.
# verify ----------------------
with self.assertRaises(ValueError):
self.cti['/ppt/presentation.xml']

def test_getitem_raises_on_bad_partname(self):
"""_ContentTypesItem[partname] raises on bad partname"""
# Loads content types from the package at zip_pkg_path, then expects
# LookupError when indexing with a partname that is not a valid part.
# NOTE(review): a test with the same name appears earlier in this diff
# without the load step; this one is presumably the removed variant --
# confirm.
# setup ------------------------
fs = FileSystem(zip_pkg_path)
self.cti.load(fs)
# verify ----------------------
with self.assertRaises(LookupError):
self.cti['!blat/rhumba.1x&']

def test_load_spotcheck(self):
"""_ContentTypesItem can load itself from a filesystem instance"""
# setup ------------------------
Expand Down

0 comments on commit 643842b

Please sign in to comment.