"""Chaco demo: linked time-series and correlation plots of synthetic returns.

The left-hand plot shows one randomly generated price-like series per symbol
against a calendar time axis.  Selecting a time range there highlights the
corresponding points in the right-hand scatter plot, which plots the two
symbols chosen in the drop-downs against each other.
"""

from numpy import linspace, random, zeros, arange, cumprod
import time

# ETS imports (non-chaco)
from enthought.enable.component_editor import ComponentEditor
from enthought.traits.api import HasTraits, Instance, Int, List, Str, Enum, \
        on_trait_change, Any, DelegatesTo
from enthought.traits.ui.api import Item, View, HSplit, VGroup, EnumEditor

# Chaco imports
from enthought.chaco.api import ArrayPlotData, Plot, PlotAxis, \
        ScatterInspectorOverlay
from enthought.chaco.scales.api import CalendarScaleSystem
from enthought.chaco.scales_tick_generator import ScalesTickGenerator
from enthought.chaco.example_support import COLOR_PALETTE
from enthought.chaco.tools.api import PanTool, ZoomTool, RangeSelection, \
        RangeSelectionOverlay, LegendTool


def create_dates(numpoints, units="days"):
    """ Returns **numpoints** evenly spaced timestamps (seconds since the
    epoch) starting at the current time and extending forward to
    ``now + numpoints * <one unit>``.

    Both endpoints are included, so consecutive values are
    ``numpoints / (numpoints - 1)`` units apart rather than exactly one unit.

    **units** should be one of "weeks", "days", "hours", "minutes", or
    "seconds"; any other value raises KeyError.
    """
    units_map = { "weeks" : 7*24*3600,
                  "days" : 24*3600,
                  "hours" : 3600,
                  "minutes" : 60,
                  "seconds" : 1 }
    now = time.time()
    dt = units_map[units]
    dates = linspace(now, now+numpoints*dt, numpoints)
    return dates


class PlotApp(HasTraits):
    """ Two-pane application: a returns time-series plot on the left and a
    scatter (correlation) plot of two selectable symbols on the right.
    """

    # Shared data container holding "times" plus one array per symbol
    plotdata = Instance(ArrayPlotData)

    # Number of samples generated for every series
    numpoints = Int(300)

    # Names of the available series
    symbols = List()

    # The two symbols plotted against each other in the correlation plot
    sym1 = Enum(values="symbols")
    sym2 = Enum(values="symbols")

    returns_plot = Instance(Plot)
    times_ds = Any()   # arraydatasource for the time axis data
    corr_plot = Instance(Plot)
    corr_renderer = Any()

    traits_view = View(
            HSplit(
                Item('returns_plot', editor=ComponentEditor(),
                    show_label = False),
                VGroup(
                    VGroup(
                        Item('sym1', width=-225),
                        Item('sym2', width=-225),
                    ),
                    Item('corr_plot', editor=ComponentEditor(),
                        show_label = False, width=-275),
                ),
            ),
            width=1024, height=500, resizable=True,
            title = "Correlations of returns")

    def __init__(self, *symbols, **kwtraits):
        """ Builds the data and both plots.

        **symbols** are the series names; **kwtraits** are passed through to
        HasTraits as initial trait values.
        """
        super(PlotApp, self).__init__(symbols=list(symbols), **kwtraits)
        self._create_data(*symbols)
        self._create_returns_plot()
        self._create_corr_plot()
        if len(self.symbols) > 1:
            self.sym2 = self.symbols[1]
        # The scatter renderer is normally created by the sym1/sym2 change
        # handler.  If the assignment above did not change anything (single
        # symbol, or sym2 already equal to symbols[1]), build it explicitly so
        # the correlation plot is not left empty.
        if self.corr_renderer is None and self.symbols:
            self._update_corr_symbols()
        return

    def _create_returns_plot(self):
        """ Creates the time-series plot (one line per series), attaches the
        pan/zoom/range-selection tools and stores it in ``returns_plot``.
        """
        plot = Plot(self.plotdata)
        plot.legend.visible = True
        #FIXME: The legend move tool doesn't seem to quite work right now
        #plot.legend.tools.append(LegendTool(plot.legend))

        # Replace the default bottom axis with a calendar-aware one
        plot.x_axis = None
        x_axis = PlotAxis(plot, orientation="bottom",
                        tick_generator=ScalesTickGenerator(scale=CalendarScaleSystem()))
        plot.overlays.append(x_axis)
        plot.x_grid.tick_generator = x_axis.tick_generator

        renderer = None
        for i, name in enumerate(self.plotdata.list_data()):
            if name == "times":
                continue
            # Wrap around the palette so that having more series than palette
            # entries does not raise IndexError.
            color = tuple(COLOR_PALETTE[i % len(COLOR_PALETTE)])
            renderer = plot.plot(("times", name), type="line", name=name,
                                 color=color)[0]

        plot.tools.append(PanTool(plot, drag_button="left", constrain=True,
            constrain_direction="x"))
        plot.overlays.append(ZoomTool(plot, tool_mode="range", max_zoom_out=1.0))

        if renderer is not None:
            # Tricky: need to set auto_handle_event on the RangeSelection
            # so that it passes left-clicks to the PanTool
            # FIXME: The range selection is still getting the initial left down
            renderer.tools.append(RangeSelection(renderer,
                left_button_selects = False, auto_handle_event = False))

            # Attach the range selection to the last renderer; any one will do
            self._range_selection_overlay = RangeSelectionOverlay(renderer,
                                        metadata_name="selections")
            renderer.overlays.append(self._range_selection_overlay)

            # Grab a reference to the Time axis datasource and add a listener
            # to its selections metadata
            self.times_ds = renderer.index
            self.times_ds.on_trait_change(self._selections_changed,
                                          'metadata_changed')
        # With no series there is no renderer to hang the range selection on;
        # the plot is still usable (empty) with pan and zoom.
        self.returns_plot = plot

    def _create_corr_plot(self):
        """ Creates the (initially empty) correlation plot with pan and zoom
        and stores it in ``corr_plot``.
        """
        plot = Plot(self.plotdata, padding=0)
        plot.padding_left = 25
        plot.padding_bottom = 25
        plot.tools.append(PanTool(plot))
        plot.overlays.append(ZoomTool(plot))
        self.corr_plot = plot

    def _create_data(self, *names):
        """ Populates ``plotdata`` with a "times" array and, for every name
        in **names**, a cumulative product of lognormal random samples.
        """
        numpoints = self.numpoints
        plotdata = ArrayPlotData(times = create_dates(numpoints))
        for name in names:
            plotdata.set_data(name,
                cumprod(random.lognormal(0.0, 0.04, size=numpoints)))
        self.plotdata = plotdata

    def _selections_changed(self, event):
        """ Mirrors the time-range selection of the returns plot onto the
        correlation scatter plot as a list of selected point indices.

        **event** is expected to be a dict carrying a "selections" entry that
        is either None (selection cleared) or a ``(low, high)`` pair in
        time-axis data space; anything else is ignored.
        """
        if self.corr_renderer is None:
            return
        if not isinstance(event, dict) or "selections" not in event:
            return
        corr_index = self.corr_renderer.index
        selections = event["selections"]
        if selections is None:
            corr_index.metadata.pop("selections", None)
        else:
            low, high = selections
            data = self.times_ds.get_data()
            low_ndx = data.searchsorted(low)
            # searchsorted() returns len(data) when `high` lies beyond the
            # last sample; clamp so no out-of-range index is emitted.
            stop_ndx = min(data.searchsorted(high) + 1, len(data))
            corr_index.metadata["selections"] = arange(low_ndx, stop_ndx, 1,
                                                       dtype=int)
        # Redraw in both cases so that clearing the selection is reflected too
        self.corr_plot.request_redraw()

    @on_trait_change("sym1,sym2")
    def _update_corr_symbols(self):
        """ Rebuilds the scatter renderer whenever either selected symbol
        changes.
        """
        plot = self.corr_plot
        if plot is None:
            # Fired before the correlation plot exists (e.g. sym1/sym2 passed
            # as constructor keywords); __init__ builds the renderer later.
            return
        if self.corr_renderer is not None:
            # Remove the old one
            plot.remove(self.corr_renderer)
            self.corr_renderer = None

        self.corr_renderer = plot.plot((self.sym1, self.sym2),
                                       type="scatter", color="blue")[0]
        self.corr_renderer.overlays.append(
            ScatterInspectorOverlay(self.corr_renderer,
                                    selection_color = "lightgreen"))
        plot.request_redraw()


demo = PlotApp("AAPL", "GOOG", "MSFT")

if __name__ == "__main__":
    demo.configure_traits()