Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# type: () -> None
with patch.object(GraphDatabase, 'driver') as mock_driver:
mock_session = MagicMock()
mock_driver.return_value.session.return_value = mock_session
mock_transaction = MagicMock()
mock_session.begin_transaction.return_value = mock_transaction
mock_run = MagicMock()
mock_transaction.run = mock_run
mock_commit = MagicMock()
mock_transaction.commit = mock_commit
publisher = Neo4jCsvPublisher()
conf = ConfigFactory.from_dict(
{neo4j_csv_publisher.NEO4J_END_POINT_KEY: 'dummy://999.999.999.999:7687/',
neo4j_csv_publisher.NODE_FILES_DIR: '{}/nodes'.format(self._resource_path),
neo4j_csv_publisher.RELATION_FILES_DIR: '{}/relations'.format(self._resource_path),
neo4j_csv_publisher.NEO4J_USER: 'neo4j_user',
neo4j_csv_publisher.NEO4J_PASSWORD: 'neo4j_password',
neo4j_csv_publisher.JOB_PUBLISH_TAG: '{}'.format(uuid.uuid4())}
)
publisher.init(conf)
publisher.publish()
self.assertEqual(mock_run.call_count, 6)
# 2 node files, 1 relation file
self.assertEqual(mock_commit.call_count, 1)
def test_transform_without_model_class_conf(self):
    # type: () -> None
    """
    Test that the transformer's init rejects a config with no model_class.

    The config only sets the index and doc_type keys. ``transformer.init`` is
    expected to raise an exception whose message names the missing
    ElasticsearchDocument model class.
    """
    config_dict = {'transformer.elasticsearch.index': self.elasticsearch_index,
                   'transformer.elasticsearch.doc_type': self.elasticsearch_type}
    transformer = ElasticsearchDocumentTransformer()
    with self.assertRaises(Exception) as context:
        transformer.init(conf=Scoped.get_scoped_conf(conf=ConfigFactory.from_dict(config_dict),
                                                     scope=transformer.get_scope()))
    # `in` cannot be applied to the exception object itself: it is not a
    # container, and on Python 3 that raises TypeError. Compare against its
    # string form instead.
    self.assertIn("User needs to provide the ElasticsearchDocument model class",
                  str(context.exception))
def setUp(self):
    # type: () -> None
    """
    Prepare the Elasticsearch index/doc-type names and a full transformer config.

    ``self.conf`` carries the index, doc_type and model_class keys, in that order.
    """
    self.elasticsearch_index = 'test_es_index'
    self.elasticsearch_type = 'test_es_type'

    settings = dict()
    settings['transformer.elasticsearch.index'] = self.elasticsearch_index
    settings['transformer.elasticsearch.doc_type'] = self.elasticsearch_type
    settings['transformer.elasticsearch.model_class'] = (
        'databuilder.models.table_elasticsearch_document.TableESDocument'
    )
    self.conf = ConfigFactory.from_dict(settings)
def os_system(cmd, raise_on_error=True):
    """
    Run ``cmd`` via ``delegator.run`` and optionally fail on a non-zero exit.

    :param cmd: command string passed unchanged to ``delegator.run``.
    :param raise_on_error: when true (the default), a non-zero ``return_code``
        prints the command's ``err`` output with ``puts`` and raises.
    :return: the object returned by ``delegator.run``. Callers, especially
        those passing ``raise_on_error=False``, can inspect ``return_code``,
        ``err`` and the rest of the result.
    :raises Exception: if ``raise_on_error`` is true and the command exits
        with a non-zero code.
    """
    p = delegator.run(cmd)
    if raise_on_error and p.return_code != 0:
        puts(p.err)
        raise Exception("Command failed: {}".format(cmd))
    # Previously the result was discarded. With raise_on_error=False a failure
    # was then undetectable by the caller.
    return p
def update_spinner_txt(spinner, txt):
    """Set ``spinner.text`` (the label shown by the yaspin spinner) to ``txt``."""
    setattr(spinner, 'text', txt)
# Names iterated by the script's `for remote in REMOTES:` loop, which runs after
# `os.chdir('upstream')`. NOTE(review): presumably these are git remote names
# configured in the `upstream` checkout; the loop body is not visible here, so
# confirm how each entry is used.
REMOTES = ['lyft', 'apache', 'hughhhh']
with yaspin(text="Loading", color="yellow") as spinner:
conf = ConfigFactory.parse_file('scripts/build.conf')
target = conf.get('target')
try:
deploy_branch = args.all[0]
commit_msg = args.all[1] if len(args.all) > 1 else '🍒'
except IndexError:
puts(colored.red('You must enter a branch name e.g. `python scripts/git_build.py {branch_name}`'))
os._exit(1)
update_spinner_txt(spinner, 'Checking out changes')
os_system('git submodule update --checkout')
os.chdir('upstream')
for remote in REMOTES:
# Module-level logger, named after the module that defines it.
LOGGER = logging.getLogger(__name__)
class Neo4jExtractor(Extractor):
"""
Extractor to fetch records from Neo4j Graph database
"""
CYPHER_QUERY_CONFIG_KEY = 'cypher_query'
GRAPH_URL_CONFIG_KEY = 'graph_url'
MODEL_CLASS_CONFIG_KEY = 'model_class'
NEO4J_AUTH_USER = 'neo4j_auth_user'
NEO4J_AUTH_PW = 'neo4j_auth_pw'
NEO4J_MAX_CONN_LIFE_TIME_SEC = 'neo4j_max_conn_life_time_sec'
DEFAULT_CONFIG = ConfigFactory.from_dict({NEO4J_MAX_CONN_LIFE_TIME_SEC: 50, })
def init(self, conf):
# type: (ConfigTree) -> None
"""
Establish connections and import data model class if provided
:param conf:
"""
self.conf = conf.with_fallback(Neo4jExtractor.DEFAULT_CONFIG)
self.graph_url = conf.get_string(Neo4jExtractor.GRAPH_URL_CONFIG_KEY)
self.cypher_query = conf.get_string(Neo4jExtractor.CYPHER_QUERY_CONFIG_KEY)
self.driver = self._get_driver()
self._extract_iter = None # type: Union[None, Iterator]
model_class = conf.get(Neo4jExtractor.MODEL_CLASS_CONFIG_KEY, None)
if model_class:
@classmethod
def from_file(cls, path):
    """Construct ``cls`` from the config tree that ``ConfigFactory.parse_file`` reads from ``path``."""
    return cls(ConfigFactory.parse_file(path))
@classmethod
def from_file(cls, path, fmt='hocon'):
    """
    Construct ``cls`` from a configuration file.

    :param path: path of the file to read.
    :param fmt: ``'hocon'`` (default) parses the file with
        ``ConfigFactory.parse_file``. ``'json'`` reads it with ``json.load``
        and converts the result with ``ConfigFactory.from_dict``.
    :return: ``cls(config_tree)`` built from the parsed configuration.
    :raises ValueError: if ``fmt`` is neither ``'hocon'`` nor ``'json'``.
    """
    if fmt == 'hocon':
        config_tree = ConfigFactory.parse_file(path)
    elif fmt == 'json':
        import io
        # JSON text is UTF-8 (RFC 8259). Decode it explicitly rather than
        # relying on the platform's locale-dependent default encoding.
        # io.open accepts `encoding` on both Python 2 and 3.
        with io.open(path, 'r', encoding='utf-8') as f:
            d = json.load(f)
        config_tree = ConfigFactory.from_dict(d)
    else:
        raise ValueError('Invalid format: {}'.format(fmt))
    return cls(config_tree)